diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 4161c2d5a4e54e731a356656bbff8864326c7fee..b37b8cc27fcc2b5f7419dc36b854fe8962ba3734 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -14,6 +14,7 @@ trigger: branches: include: - develop + - amd-develop paths: exclude: - .github diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 459315e58b766043355046569379ab96500a3449..f6ab388e2a509281e6c595b0c58f28cfb8da979c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +* @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj # Documentation files -docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj diff --git 
a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..56f2acee71053106143afd49156f4eff0d76193d --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,10 @@ +We'd love for you to contribute to our source code! + +Some helpful links: + +- [Code of Conduct guidelines](https://www.contributor-covenant.org/version/2/1/code_of_conduct/code_of_conduct.txt) +- [New issue guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/ISSUE_TEMPLATE.md) +- [Submitting a pull request guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/PULL_REQUEST_TEMPLATE.md) +- [Maintainers](https://github.com/rocm/composable_kernel/blob/develop/CONTRIBUTORS.md) +- [General information](https://github.com/rocm/composable_kernel/blob/develop/README.md) +- [ROCm documentation](https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/optimizing-with-composable-kernel.html) \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..263cc3480dea6eca22a8a9039a0e3e64c25d56b5 --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,14 @@ +When creating an issue, please check if a similar issue already exists. 
+ +### When reporting a bug, please include: +- [ ] A descriptive title +- [ ] An isolated way to reproduce the behavior (preferably a docker container with a repro) +- [ ] ROCm version, clang version, Composable Kernel commit pin +- [ ] Environment variables +- [ ] The behavior you expect to see, and the behavior you actually see + +### When requesting a feature, please include: +- [ ] A descriptive title +- [ ] A detailed description of the problem you are trying to solve +- [ ] An overview of the suggested solution +- [ ] An explanation of why the solution is an improvement \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..8a988ad1c9e4dc50b57feb2bd3eed542e2801a30 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ +## Proposed changes + +Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request. + +## Checklist + +Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. + +- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally +- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. 
+- [ ] I have added inline documentation which helps the maintainers understand the motivation +- [ ] I have removed the stale documentation which is no longer relevant after this pull request +- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request +- [ ] I have run `clang-format` on all changed files +- [ ] Any dependent changes have been merged + +## Discussion + +If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered. + diff --git a/CMakeLists.txt b/CMakeLists.txt index b28a6d91274d0c7c7394edc594d50435be8541e2..1fe1bc91d520e602e387664ca20948076e07a44c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,9 +97,20 @@ if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") endif() +if(DPP_KERNELS) + add_definitions(-DDPP_KERNELS) + set(CK_ENABLE_DPP_KERNELS "ON") +endif() option(CK_USE_CODEGEN "Enable codegen library" OFF) if(CK_USE_CODEGEN) - add_definitions(-DCK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) +endif() + +option(CK_TIME_KERNEL "Enable kernel time tracking" ON) +if(CK_TIME_KERNEL) + add_definitions(-DCK_TIME_KERNEL=1) +else() + add_definitions(-DCK_TIME_KERNEL=0) endif() include(getopt) @@ -183,18 +194,38 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") message("Enabling XDL instances") add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") - message("Enabling FP8 gemms in ckProfiler") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") + message("Enabling FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) + set(CK_USE_GFX94 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") + add_definitions(-DCK_USE_AMD_MFMA_GFX950) endif() if 
(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_OCP_FP8) + set(CK_USE_OCP_FP8 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94") + add_definitions(-DCK_USE_FNUZ_FP8) + set(CK_USE_FNUZ_FP8 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") +endif() + option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) + set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON") endif() # CK config file to record supported datatypes, etc. 
@@ -516,7 +547,13 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() +# make check runs the entire set of examples and tests add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +# make smoke runs the tests and examples that run within 30 seconds on gfx90a +add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") +# make regression runs the tests and examples that run for more than 30 seconds on gfx90a +add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) @@ -572,7 +609,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) ) add_subdirectory(example) if(BUILD_TESTING) - add_subdirectory(test) + add_subdirectory(test) endif() endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index cdce5a46309f59b27d8c658e785f70bf743527db..8ef5c2b726cd4a93ce8278ca3fd093c7da144d5a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -1,3 +1,4 @@ +[Back to the main page](./README.md) # Composable Kernel Developers and Contributors This is the list of developers and contributors to Composable Kernel library diff --git a/Dockerfile b/Dockerfile index 791d1d9f3ab8143e74ead44187e7a8bb9b764f8f..2873a8500b6d95e36a3fe2856d28eb20e3bf8fa3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,37 +1,28 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.2 +ARG ROCMVERSION=6.3 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" - -RUN set -xe - ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ -RUN useradd -rm -d /home/jenkins -s /bin/bash 
-u 1004 jenkins -# Add rocm repository -RUN chmod 1777 /tmp -RUN apt-get update -RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl - ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn -RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.3" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.2.60200-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.2.60200-1_all.deb && \ +# Add rocm repository +RUN set -xe && \ + useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins && \ + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ + curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + +RUN if [ "$ROCMVERSION" != "6.4" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 
rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=2074281; \ fi -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" -RUN amdgpu-install -y --usecase=rocm --no-dkms +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" && \ + amdgpu-install -y --usecase=rocm --no-dkms ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -57,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ + mpich \ net-tools \ pkg-config \ python \ @@ -72,72 +64,52 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- nano \ zlib1g-dev \ zip \ + libzstd-dev \ openssh-server \ clang-format-12 \ kmod && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +# Remove unnecessary rocm components that take a lot of space + apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt -# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 -RUN if [ "$ROCMVERSION" = "6.1" ]; then \ - sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \ - fi # Update the cmake to version 3.27.5 -RUN pip install --upgrade cmake==3.27.5 - +RUN pip install --upgrade cmake==3.27.5 && \ #Install latest ccache -RUN git clone https://github.com/ccache/ccache.git && \ - cd ccache && mkdir build && cd build && cmake .. 
&& make install - + git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools -RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip -RUN gunzip /usr/local/bin/ninja.gz -RUN chmod a+x /usr/local/bin/ninja -RUN git clone https://github.com/nico/ninjatracing.git - + cd / && \ + wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip && \ + gunzip /usr/local/bin/ninja.gz && \ + chmod a+x /usr/local/bin/ninja && \ + git clone https://github.com/nico/ninjatracing.git && \ #Install latest cppcheck -RUN git clone https://github.com/danmar/cppcheck.git && \ - cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . -WORKDIR / - -# Setup ubsan environment to printstacktrace -RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer -ENV UBSAN_OPTIONS=print_stacktrace=1 - + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . 
&& \ + cd / && \ # Install an init system -RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb -RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb - -ARG PREFIX=/opt/rocm + wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ + dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results -RUN pip3 install --upgrade pip -RUN pip3 install sqlalchemy==1.4.46 -RUN pip3 install pymysql -RUN pip3 install pandas==2.0.3 -RUN pip3 install setuptools-rust -RUN pip3 install sshtunnel==0.4.0 -# Setup ubsan environment to printstacktrace -ENV UBSAN_OPTIONS=print_stacktrace=1 - -ENV LC_ALL=C.UTF-8 -ENV LANG=C.UTF-8 -RUN groupadd -f render - + pip3 install --upgrade pip && \ + pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust "setuptools>=75" sshtunnel==0.4.0 && \ # Add render group + groupadd -f render && \ # Install the new rocm-cmake version -RUN git clone -b master https://github.com/ROCm/rocm-cmake.git && \ - cd rocm-cmake && mkdir build && cd build && \ - cmake .. && cmake --build . && cmake --build . + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . 
--target install WORKDIR / - +# Add alternative compilers, if necessary ENV compiler_version=$compiler_version ENV compiler_commit=$compiler_commit -RUN sh -c "echo compiler version = '$compiler_version'" -RUN sh -c "echo compiler commit = '$compiler_commit'" - -ARG DISABLE_CACHE=0 +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -145,16 +117,10 @@ RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi - -#clean-up the deb package -RUN sh -c "rm -rf amdgpu-install*" - 
-#ENV HIP_CLANG_PATH='/llvm-project/build/bin' -#RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/Dockerfile.compiler b/Dockerfile.compiler new file mode 100644 index 0000000000000000000000000000000000000000..a22103b96b30f22eb5c8ddd0edd21ee6b0e737a2 --- /dev/null +++ b/Dockerfile.compiler @@ -0,0 +1,26 @@ +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3" +FROM $BASE_DOCKER +ARG compiler_version="" +ARG compiler_commit="" + +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 16 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 16 ; \ + else echo "using the release compiler"; \ + fi diff --git a/Jenkinsfile b/Jenkinsfile index 
b79b2045b0750b2244d643df72bbeaaebd613f43..835b7e724f1c40818ef38c5b31f54f2b9dd92fb4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -32,41 +32,43 @@ def runShell(String command){ return (output != "") } -def getDockerImageName(){ +def getBaseDockerImageName(){ def img if (params.USE_CUSTOM_DOCKER != ""){ img = "${params.USE_CUSTOM_DOCKER}" } else{ - if (params.ROCMVERSION != "6.3"){ - if (params.COMPILER_VERSION == "") { - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" - } - else{ - if (params.COMPILER_COMMIT == ""){ - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" - } - else{ - def commit = "${params.COMPILER_COMMIT}"[0..6] - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" - } - } + def ROCM_numeric = "${params.ROCMVERSION}" as float + if ( ROCM_numeric < 6.4 ){ + img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}" + } + else{ + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}" + } + } + return img +} + +def getDockerImageName(){ + def img + def base_name = getBaseDockerImageName() + if (params.USE_CUSTOM_DOCKER != ""){ + img = "${params.USE_CUSTOM_DOCKER}" } else{ if (params.COMPILER_VERSION == "") { - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}" + img = "${base_name}" } else{ if (params.COMPILER_COMMIT == ""){ - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + img = "${base_name}_${params.COMPILER_VERSION}" } else{ def commit = "${params.COMPILER_COMMIT}"[0..6] - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + img = "${base_name}_${params.COMPILER_VERSION}_${commit}" } } } - } return img } @@ -131,17 +133,21 @@ def buildDocker(install_prefix){ env.DOCKER_BUILDKIT=1 checkout scm def image_name = getDockerImageName() + def base_image_name = getBaseDockerImageName() echo 
"Building Docker for ${image_name}" - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' " - if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ - dockerArgs = dockerArgs + " --no-cache " + def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ + dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " + } + else{ + dockerArgs = dockerArgs + " -f Dockerfile . 
" } echo "Build Args: ${dockerArgs}" try{ if(params.BUILD_DOCKER){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" - retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage = docker.build("${image_name}", dockerArgs) withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { retimage.push() } @@ -320,14 +326,38 @@ def cmake_build(Map conf=[:]){ if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } + //check the node gpu architecture + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - archiveArtifacts "perf_fmha_fwd_*.log" - archiveArtifacts "perf_fmha_bwd_*.log" - stash name: "perf_fmha_fwd_gfx942.log" - stash name: "perf_fmha_bwd_gfx942.log" - stash name: "perf_fmha_fwd_gfx90a.log" - stash name: "perf_fmha_bwd_gfx90a.log" + archiveArtifacts "perf_fmha_*.log" + if (arch_type == 1){ + stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + archiveArtifacts "perf_tile_gemm_*.log" + if (arch_type == 1){ + stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942" + } } catch(Exception err){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." 
@@ -353,12 +383,12 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') @@ -373,7 +403,7 @@ def buildHipClangJob(Map conf=[:]){ gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 48, unit: 'HOURS') + timeout(time: 20, unit: 'HOURS') { cmake_build(conf) } @@ -402,128 +432,6 @@ def buildHipClangJobAndReboot(Map conf=[:]){ } } -def runCKProfiler(Map conf=[:]){ - show_node_info() - - env.HSA_ENABLE_SDMA=0 - checkout scm - - def image = getDockerImageName() - def prefixpath = conf.get("prefixpath", "/opt/rocm") - - // Jenkins is complaining about the render group - def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video 
--group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" - if (conf.get("enforce_xnack_on", false)) { - dockerOpts = dockerOpts + " --env HSA_XNACK=1 " - } - def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') - def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') - dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " - echo "Docker flags: ${dockerOpts}" - - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - - def variant = env.STAGE_NAME - def retimage - - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'rocminfo | tee rocminfo.log' - if ( !runShell('grep -n "gfx" rocminfo.log') ){ - throw new Exception ("GPU not found") - } - else{ - echo "GPU is OK" - } - } - } - } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e - } - - withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 24, unit: 'HOURS') - { - sh """ - rm -rf build - mkdir build - """ - dir("build"){ - unstash 'ckProfiler.tar.gz' - sh 'tar -xvf ckProfiler.tar.gz' - } - - dir("script"){ - if (params.RUN_FULL_QA){ - sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm.log" - archiveArtifacts "perf_resnet50_N256.log" - archiveArtifacts "perf_resnet50_N4.log" - archiveArtifacts "perf_batched_gemm.log" - archiveArtifacts "perf_grouped_gemm.log" - archiveArtifacts "perf_grouped_conv_fwd.log" - 
archiveArtifacts "perf_grouped_conv_bwd_data.log" - archiveArtifacts "perf_grouped_conv_bwd_weight.log" - archiveArtifacts "perf_gemm_bilinear.log" - archiveArtifacts "perf_reduction.log" - archiveArtifacts "perf_splitK_gemm.log" - archiveArtifacts "perf_onnx_gemm.log" - archiveArtifacts "perf_mixed_gemm.log" - // stash perf files to master - stash name: "perf_gemm.log" - stash name: "perf_resnet50_N256.log" - stash name: "perf_resnet50_N4.log" - stash name: "perf_batched_gemm.log" - stash name: "perf_grouped_gemm.log" - stash name: "perf_grouped_conv_fwd.log" - stash name: "perf_grouped_conv_bwd_data.log" - stash name: "perf_grouped_conv_bwd_weight.log" - stash name: "perf_gemm_bilinear.log" - stash name: "perf_reduction.log" - stash name: "perf_splitK_gemm.log" - stash name: "perf_onnx_gemm.log" - stash name: "perf_mixed_gemm.log" - //we will process results on the master node - } - else{ - sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm.log" - archiveArtifacts "perf_resnet50_N256.log" - archiveArtifacts "perf_resnet50_N4.log" - // stash perf files to master - stash name: "perf_gemm.log" - stash name: "perf_resnet50_N256.log" - stash name: "perf_resnet50_N4.log" - //we will process the results on the master node - } - } - } - } - } - return retimage -} - -def runPerfTest(Map conf=[:]){ - try{ - runCKProfiler(conf) - } - catch(e){ - echo "throwing error exception in performance tests" - echo 'Exception occurred: ' + e.toString() - throw e - } - finally{ - if (!conf.get("no_reboot", false)) { - reboot() - } - } -} - def Build_CK(Map conf=[:]){ show_node_info() @@ -544,12 +452,12 @@ def Build_CK(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="-u root 
--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } if(params.BUILD_LEGACY_OS){ @@ -567,7 +475,7 @@ def Build_CK(Map conf=[:]){ try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ + timeout(time: 2, unit: 'MINUTES'){ sh 'rocminfo | tee rocminfo.log' if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") @@ -583,36 +491,102 @@ def Build_CK(Map conf=[:]){ throw e } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 24, unit: 'HOURS') + timeout(time: 20, unit: 'HOURS') { //check whether to run performance tests on this node - def do_perf_tests = 0 + def arch_type = 0 sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx1201" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){ - do_perf_tests = 1 - echo "Stash profiler and run performance tests" + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } + else if ( runShell('grep -n "gfx1030" rocminfo.log') ) { + arch_type = 3 + } + else if ( 
runShell('grep -n "gfx1101" rocminfo.log') ) { + arch_type = 4 + } + else if ( runShell('grep -n "gfx1201" rocminfo.log') ) { + arch_type = 5 } cmake_build(conf) + if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ + echo "Run inductor codegen tests" + sh """ + pip install --verbose . + pytest python/test/test_gen_instances.py + """ + } dir("build"){ - //run tests and examples - //sh 'make -j check' - if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){ - //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on nodes where we don't need to run performance tests - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash name: "ckProfiler.tar.gz" - } - if (params.RUN_FULL_QA && do_perf_tests == 0 ){ - // build deb packages for all gfx9 targets and prepare to export + if (params.RUN_FULL_QA && arch_type == 1 ){ + // build deb packages for all gfx9 targets on gfx90a system and prepare to export + echo "Build ckProfiler package" sh 'make -j package' archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - archiveArtifacts artifacts: 'composablekernel-tests_*.deb' sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash name: "ckprofiler_0.2.0_amd64.deb" + stash includes: "ckprofiler_0.2.0_amd64.deb", name: "ckprofiler_0.2.0_amd64.deb" + } + } + // run performance tests, stash the logs, results will be processed on the master node + dir("script"){ + if (params.RUN_PERFORMANCE_TESTS){ + if (params.RUN_FULL_QA && arch_type == 1){ + // run full tests on gfx90a + echo "Run full performance tests" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + archiveArtifacts "perf_gemm.log" + archiveArtifacts "perf_resnet50_N256.log" + archiveArtifacts "perf_resnet50_N4.log" + archiveArtifacts "perf_batched_gemm.log" + archiveArtifacts "perf_grouped_gemm.log" + archiveArtifacts "perf_grouped_conv_fwd.log" + archiveArtifacts 
"perf_grouped_conv_bwd_data.log" + archiveArtifacts "perf_grouped_conv_bwd_weight.log" + archiveArtifacts "perf_gemm_bilinear.log" + archiveArtifacts "perf_reduction.log" + archiveArtifacts "perf_splitK_gemm.log" + archiveArtifacts "perf_onnx_gemm.log" + archiveArtifacts "perf_mixed_gemm.log" + stash includes: "perf_**.log", name: "perf_log" + } + else if ( arch_type == 1 ){ + // run standard tests on gfx90a + echo "Run performance tests" + sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + archiveArtifacts "perf_gemm.log" + archiveArtifacts "perf_onnx_gemm.log" + archiveArtifacts "perf_resnet50_N256.log" + archiveArtifacts "perf_resnet50_N4.log" + stash includes: "perf_**.log", name: "perf_log" + } + // disable performance tests on gfx1030 for now. + //else if ( arch_type == 3){ + // run basic tests on gfx1030 + // echo "Run gemm performance tests" + // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" + // archiveArtifacts "perf_onnx_gemm_gfx10.log" + // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" + //} + else if ( arch_type == 4){ + // run basic tests on gfx11 + echo "Run gemm performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" + archiveArtifacts "perf_onnx_gemm_gfx11.log" + stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" + } + else if ( arch_type == 5 ){ + // run basic tests on gfx12 + echo "Run gemm performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" + archiveArtifacts "perf_onnx_gemm_gfx12.log" + stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" + } } } - if (params.hipTensor_test && do_perf_tests == 0 ){ - //build and test hipTensor + if (params.hipTensor_test && arch_type == 1 ){ + // build and test hipTensor on gfx90a node sh """#!/bin/bash rm -rf 
"${params.hipTensor_branch}".zip rm -rf hipTensor-"${params.hipTensor_branch}" @@ -625,11 +599,9 @@ def Build_CK(Map conf=[:]){ ls -ltr CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" cmake --build build -- -j + ctest --test-dir build """ } - dir("hipTensor-${params.hipTensor_branch}/build"){ - sh 'ctest' - } } } } @@ -679,44 +651,51 @@ def process_results(Map conf=[:]){ } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 1, unit: 'HOURS'){ + timeout(time: 15, unit: 'MINUTES'){ try{ dir("script"){ if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - unstash "perf_fmha_fwd_gfx942.log" - unstash "perf_fmha_bwd_gfx942.log" - unstash "perf_fmha_fwd_gfx90a.log" - unstash "perf_fmha_bwd_gfx90a.log" + unstash "perf_fmha_log_gfx942" + unstash "perf_fmha_log_gfx90a" } catch(Exception err){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + unstash "perf_tile_gemm_log_gfx942" + unstash "perf_tile_gemm_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the GEMM performance logs: ${err.getMessage()}." 
+ } + } if (params.RUN_FULL_QA){ // unstash perf files to master unstash "ckprofiler_0.2.0_amd64.deb" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - unstash "perf_gemm.log" - unstash "perf_resnet50_N256.log" - unstash "perf_resnet50_N4.log" - unstash "perf_batched_gemm.log" - unstash "perf_grouped_gemm.log" - unstash "perf_grouped_conv_fwd.log" - unstash "perf_grouped_conv_bwd_data.log" - unstash "perf_grouped_conv_bwd_weight.log" - unstash "perf_gemm_bilinear.log" - unstash "perf_reduction.log" - unstash "perf_splitK_gemm.log" - unstash "perf_onnx_gemm.log" - unstash "perf_mixed_gemm.log" + unstash "perf_log" + try{ + unstash "perf_log_gfx11" + unstash "perf_log_gfx12" + } + catch(Exception err){ + echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." + } sh "./process_qa_data.sh" } else{ // unstash perf files to master - unstash "perf_gemm.log" - unstash "perf_resnet50_N256.log" - unstash "perf_resnet50_N4.log" + unstash "perf_log" + try{ + unstash "perf_log_gfx11" + unstash "perf_log_gfx12" + } + catch(Exception err){ + echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." + } sh "./process_perf_data.sh" } } @@ -734,10 +713,10 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false 0 13 * * * % BUILD_LEGACY_OS=true''' : "" @@ -760,12 +739,12 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.2', - description: 'Specify which ROCM version to use: 6.2 (default).') + defaultValue: '6.3', + description: 'Specify which ROCM version to use: 6.3 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', @@ -816,16 +795,16 @@ pipeline { description: "Run the ck_tile 
FMHA tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", - defaultValue: false, - description: "Run the ck_tile GEMM tests (default: OFF)") + defaultValue: true, + description: "Run the ck_tile GEMM tests (default: ON)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, description: "Test building instances for various architectures simultaneously (default: OFF)") booleanParam( name: "BUILD_GFX12", - defaultValue: false, - description: "Build CK and run tests on gfx12 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx12 (default: ON)") booleanParam( name: "NINJA_BUILD_TRACE", defaultValue: false, @@ -1019,7 +998,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 tile_example_gemm_basic && \ + make -j64 tile_example_gemm_basic tile_example_gemm_universal && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ } @@ -1038,7 +1017,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_gemm_basic && \ + make -j64 tile_example_gemm_basic tile_example_gemm_universal && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ } @@ -1235,29 +1214,6 @@ pipeline { } } } - - stage("Performance Tests") - { - parallel - { - stage("Run ckProfiler: gfx90a") - { - when { - beforeAgent true - expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } - } - options { retry(1) } - agent{ label rocmnode("gfx90a")} - environment{ - setup_args = "NO_CK_BUILD" - } - steps{ - runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') - cleanWs() - } - } - } - } stage("Process Performance Test Results") { parallel diff 
--git a/LICENSE b/LICENSE index 581b5efde535f686aa0a584709fccaa92353d125..68f6ae5746ddd5c79829fd7cc7e32584ffd9d822 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 302173dc1729fde4c1f81277193031a291796149..95f44d887263260b89c8a73048cf4d78c3405974 100644 --- a/README.md +++ b/README.md @@ -26,23 +26,15 @@ The current CK library is structured into four layers: ## General information -To build our documentation locally, use the following code: - -``` bash -cd docs -pip3 install -r sphinx/requirements.txt -python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html -``` - -You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page. - -```note -If you use CK, cite us as follows: - -* [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???): - This paper will be available on arXiv soon. -* [CITATION.cff](/CITATION.cff) -``` +* [CK supported operations](include/ck/README.md) +* [CK Tile supported operations](include/ck_tile/README.md) +* [CK wrapper](client_example/25_wrapper/README.md) +* [CK codegen](codegen/README.md) +* [CK profiler](profiler/README.md) +* [Examples (Custom use of CK supported operations)](example/README.md) +* [Client examples (Use of CK supported operations with instance factory)](client_example/README.md) +* [Terminology](/TERMINOLOGY.md) +* [Contributors](/CONTRIBUTORS.md) CK is released under the **[MIT license](/LICENSE)**. 
@@ -129,6 +121,15 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You can find instructions for running each individual example in [example](/example). +* Build and run smoke/regression examples and tests: + + ```bash + make -j smoke # tests and examples that run for < 30 seconds each + ``` + ```bash + make -j regression # tests and examples that run for >= 30 seconds each + ``` + * Build ckProfiler: ```bash @@ -137,6 +138,14 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You can find instructions for running ckProfiler in [profiler](/profiler). +* Build our documentation locally: + + ``` bash + cd docs + pip3 install -r sphinx/requirements.txt + python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html + ``` + Note the `-j` option for building with multiple threads in parallel, which speeds up the build significantly. However, `-j` launches unlimited number of threads, which can cause the build to run out of memory and crash. On average, you should expect each thread to use ~2Gb of RAM. @@ -153,9 +162,11 @@ Additional cmake flags can be used to significantly speed-up the build: `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available. +* `DPP_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dpp`. + These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available. + * `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances, - such as `gemm_universal` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not - have native support for fp8 data type, such as gfx908 or gfx90a. 
These instances are useful on + such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on architectures like the MI100/MI200 for the functional support only. ## Using sccache for building diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md new file mode 100644 index 0000000000000000000000000000000000000000..e8833efb89d68674a3c8b8894047e16c8d45e67f --- /dev/null +++ b/TERMINOLOGY.md @@ -0,0 +1,2 @@ +[Back to the main page](./README.md) +# Composable Kernel terminology \ No newline at end of file diff --git a/client_example/01_gemm/README.md b/client_example/01_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6dcd1e29598a4be40f117c65b1101bda62639ccf --- /dev/null +++ b/client_example/01_gemm/README.md @@ -0,0 +1,126 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel GEMM + +## GEMM +General matrix multiplications operation. In CK GEMM operation is called as `DeviceGemm` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - C matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - C matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +For matrices with large K dimension `DeviceGemmSplitK` implementation is available. This implementation allows user to split K dimension between work groups. This implementation uses `AtomicAdd` operation on global memory, thus the output buffer needs to be zeroed out for correct results.
+ +For fused operations with additional tensors there are the `DeviceGemmMultipleABD` and `DeviceGemmMultipleD` operations, which require the following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +For `DeviceGemmMultipleABD`, the user should pass a tuple for **ALayout**, **BLayout**, **ADataType** and **BDataType**. + +List of the device operations in CK: + +* **DeviceGemmDl** - Device operation with DL instructions. +* **DeviceGemmDpp** - Device operation with DL instructions, using DPP instructions during data load. +* **DeviceGemmWmma_CShuffle** - Device operation with WMMA instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffle_LdsDirectLoad** - Device operation with XDL instructions and CShuffle optimization for more optimized data store and direct load from global memory to shared memory. +* **DeviceGemm_Xdl_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV2** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. GEMM pipeline has been optimized compared to **DeviceGemm_Xdl_CShuffle**. +* **DeviceGemmXdlSkipBLds** - Device operation with XDL instructions. Load to shared memory is skipped for the B matrix. +* **DeviceGemm_Xdl_WaveletModel_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. Producer and consumer scheme cooperation between waves in workgroup. +* **DeviceGemmXdl** - Device operation with XDL instructions.
+ +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✓| + +Table of supported cases by instance factory with WMMA instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with DL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✗| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with fused output elementwise operation: + +* **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Add** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply** - bf16 (int8 for B matrix) + +* **Add + Add + Gelu** - fp16 +* **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row +* **Multiply** - fp16 +* **Add + Multiply** - fp16 +* **Add + Relu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add + Silu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Bilinear** - fp16, int8 +* **Gelu** - fp16 +* **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row, +* **Quantization** - int8 + +## GEMM V2 (Universal GEMM) +General matrix multiplications operation optimized for MI300 series. Operation is called as `DeviceGemmV2` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). 
+* **CLayout** - C matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - C matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +This implementation allows the user to split the K dimension between work groups. This implementation requires AtomicAdd operation on global memory (output buffer must be set to zeroes if splitK parameter is larger than one). + +List of the device operations in CK: + +* **DeviceGemm_Xdl_CShuffleV3** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV3R1** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. This implementation performs the reduction on the split K dimension after GEMM instead of using the AtomicAdd instruction. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✗| +|fp8 (C bf16)|✓| +|fp16 (A fp8)|✓| +|fp16 (B fp8)|✓| + +## Others + +* **DeviceGemm_dequantB** - GEMM with dequantization (implemented with WMMA instructions). +* **DeviceGemmMultipleD_ABScale** - GEMM with scale for A and B matrix. +* **DeviceGemmMultipleDLayernorm** - GEMM fused with layernorm. +* **DeviceGemmMultipleDMultipleR** - GEMM fused with reductions and custom global reductions operators. +* **DeviceGemmReduce** - GEMM fused with reduction. +* **DeviceGemm_Streamk_V2** - GEMM stream K implementation. The implementation allows using reduction instead of AtomicAdd. +* **DeviceGemmStreamK** - GEMM stream K implementation using AtomicAdd.
diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index c953e21d0266f61b1b9fc99de9252a4f00bd57cd..2ea31bdf068149c310b5e90647217e9000884dea 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -22,4 +22,7 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() + + add_executable(grouped_conv2d_fwd_ngchw grouped_conv2d_fwd_ngchw.cpp) + target_link_libraries(grouped_conv2d_fwd_ngchw PRIVATE composable_kernel::device_conv_operations) endif() diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28a64ad7337f6ac1bdb218c4b48a74085b2a369d --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -0,0 +1,68 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Forward +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Forward operation is called as `DeviceGroupedConvFwdMultipleABD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **ADataType** - input data type. Pass tuple if there is fused operation with input. +* **BDataType** - weight data type. Pass tuple if there is fused operation with weight. 
+* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - Output data type. +* **AElementwiseOperation** - fused operation on tensor A (input). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (output). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution forward supports tensors larger than 2GB. + +List of the device operations for grouped convolution forward in CK: + +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** - Device operation with XDL instructions. Optimized for AMD Instinct MI300 series. +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to input, weight and output. +* **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** - Device operation with DL instructions.
+ +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |2D, 3D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|1D, 3D| +|fp8 |3D|✗|✗| +|bf8 |3D|✗|✗| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |✗|✗|2D| +|fp16 |✗|✗|2D| +|fp32 |✗|✗|2D| +|int8 |✗|✗|2D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Dynamic elementwise operation** - 2D/3D, NHWGC, bf16/fp16/fp32/int8 +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **ConvInvScale** - 3D, NHWGC, fp8 +* **ConvScale** - 3D, NHWGC, fp8/bf8 +* **ConvScale + Add** - 3D, NHWGC, fp8 +* **ConvScale + Relu** - 3D, NHWGC, fp8 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add (for A and B)** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add + Scale + Add + Relu** - 3D, NHWGC, bf16/fp16/fp32/int8 diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..480abf23d24747f2f4a3e93ba09f8cddbd058139 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NGCHW; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::NGKHW; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd() +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, 
N, K, Ho, Wo}; + std::array out_strides{K * Ho * Wo, G * K * Ho * Wo, Ho * Wo, Wo, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + // workspace_sz will be equal to 0 for other layout than NGCHW + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + 
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + 
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd(); } diff --git a/client_example/10_grouped_convnd_bwd_data/README.md b/client_example/10_grouped_convnd_bwd_data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ed133310e1f4af03e232a6177fca09ad7467240 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/README.md @@ -0,0 +1,48 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Data + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK, the Grouped Convolution Backward Data operation is called `DeviceGroupedConvBwdDataMultipleD` and requires the following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **ALayout** - output layout (NHWGK, GNHWK, NGKHW). +* **BLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **ELayout** - input layout (NHWGC, GNHWC, NGCHW). +* **ADataType** - output data type. +* **BDataType** - weight data type. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - input data type. +* **AElementwiseOperation** - fused operation on tensor A (output). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (input). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution backward data supports tensors larger than 2GB (except when image is larger than 2GB).
+ +List of the device operations for grouped convolution backward data in CK: + +* **DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1** - Device operation with XDL instructions and support of fused operations to input. +* **DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|2D, 3D| +|fp16 |2D, 3D|✗|2D, 3D| +|fp32 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32 diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ed3dff0f1e80dd494fa2b1c5eb051a9a65ba34d4 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -0,0 +1,62 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Weight + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. Backward weight version uses splitK feature (due to large GEMM K dimension). In CK Grouped Convolution Backward Weight operation is called as `DeviceGroupedConvBwdWeight` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). 
+* **InDataType** - input data type. +* **WeiDataType** - weight data type. +* **OutDataType** - output data type. +* **InElementwiseOperation** - fused operation on tensor input. +* **WeiElementwiseOperation** - fused operation on tensor weight. +* **OutElementwiseOperation** - fused operation on tensor output. +* **ComputeTypeA** - compute data type of tensor A for mfma instruction (ADataType by default). +* **ComputeTypeB** - compute data type of tensor B for mfma instruction (ComputeTypeA by default). + +For fused operations with additional tensors there is the `DeviceGroupedConvBwdWeightMultipleD` operation, which requires the following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +Grouped convolution backward weight doesn't support tensors larger than 2GB. + +List of the device operations for grouped convolution backward weight in CK: + +* **DeviceGroupedConvBwdWeight_Xdl_CShuffle** - Device operation with XDL instructions. +* **DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle** - Device operation with XDL instructions. Optimized for small C or K. +* **DeviceGroupedConvBwdWeight_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to output. +* **DeviceGroupedConvBwdWeight_Dl** - Device operation with DL instructions.
+ +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|✗| +|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| +|fp16 |2D, 3D|✗|1D, 2D, 3D| +|fp32 |2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |3D|✗|3D| +|int8 |3D|✗|3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16(fp32 for weight)|1D, 2D, 3D|✗|1D, 2D, 3D| +|fp16 |1D, 2D, 3D|✗|1D, 2D, 3D| +|fp32 |1D, 2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index dc55250bfe78ea11e48ad75fd0d522a54e66e9ae..67bbdfec4505bec90285366e8fab670e0685e247 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -54,7 +54,7 @@ target_link_libraries(client_conv3d_fwd_convscale_relu_amax_fp8 PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_reduction_operations - utility) + composable_kernel::utility) # Fwd convscale + AMAX add_executable(client_conv3d_fwd_convscale_amax_fp8 grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp) @@ -62,7 +62,7 @@ target_link_libraries(client_conv3d_fwd_convscale_amax_fp8 PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_reduction_operations - utility) + composable_kernel::utility) # Fwd convscale 
add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md index eba3de017f41bf7cc7de8a4f2cd581dfd81c0093..3db9a9af44e5422da630f6328c69c69820c857a1 100644 --- a/client_example/25_wrapper/README.md +++ b/client_example/25_wrapper/README.md @@ -1,14 +1,9 @@ +[Back to the main page](../../README.md) # Composable Kernel wrapper GEMM tutorial -This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) -wrapper. We present the base version of GEMM without most of the available optimizations; however, -it's worth noting that CK has kernels with different optimizations. +This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) wrapper. We present the base version of GEMM without most of the available optimizations; however, it's worth noting that CK has kernels with different optimizations. -To implement these optimizations, you can use the CK wrapper or directly use available instances in -CK. You can also refer to the -[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), -that uses CK wrapper based on the -[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. +To implement these optimizations, you can use the CK wrapper or directly use available instances in CK. You can also refer to the [optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), that uses CK wrapper based on the [`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. 
The kernel definition should look similar to: diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp index 4b284c74d4a75cf3634b77e703d369df44ba8098..47d3e0abf94d019761be9a8c667abbeb57a905ea 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -121,7 +121,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 2; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp index 6cc83e06f68555f84d642219e3f49aafd627fa66..8c705d3bcc78a3f16e75ac98def07c0922770e83 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -120,7 +120,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 1; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index acb57d7bb045e8d7d917fc42ba5ea1e04e4cb252..9e2012bf8a7ce28333ac51167d3274144164a570 100644 --- a/client_example/CMakeLists.txt +++ 
b/client_example/CMakeLists.txt @@ -56,13 +56,21 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() + if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_OCP_FP8) + set(CK_USE_OCP_FP8 "ON") + endif() + if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94") + add_definitions(-DCK_USE_FNUZ_FP8) + set(CK_USE_FNUZ_FP8 "ON") + endif() else() add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") set(CK_USE_WMMA "ON") endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations utility) if(GPU_TARGETS MATCHES "gfx9") find_package(composable_kernel COMPONENTS device_contraction_operations) endif() diff --git a/client_example/README.md b/client_example/README.md index 64a7130d537b1e2fb8752c4031e8430d11a6a46a..d9f793434db9dd0786aef36efca43432403e0b98 100644 --- a/client_example/README.md +++ b/client_example/README.md @@ -1,3 +1,5 @@ +[Back to the main page](../README.md) +# Composable Kernel client examples ## Client application links to CK library, and therefore CK library needs to be installed before building client applications. 
diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 93fd306e98af3bf86fcc6f0f213d029f6f3c4a26..fb2b38d688d4141a91be7b48cd96b024edfb0d59 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 1ca0d12821067edb2ebdd8b763a894168e70e647..45c47672b0ab01d53e09456616d376a49b757353 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -7,6 +7,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) +configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) diff --git a/codegen/README.md b/codegen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..deadf3221dfea8a47c6f65a861bdaf37981ddb89 --- /dev/null +++ b/codegen/README.md @@ -0,0 +1,2 @@ +[Back to the main page](../README.md) +# Composable Kernel codegen \ No newline at end of file diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index c7d295de943e1feb5b139d933118db745e7edee3..7b878d0d579635b5022a7a0167c73daf68e9d134 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 5b0c929db32fc9a834ec040163b2a5476e2b6d1c..452cd998469702603ed2f33e931ba2622e16ba87 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/headers.hpp" #include "ck_headers.hpp" diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index a8a8b10c04d522e93dc7340167e46ca83f51259b..9aa5d39fae34de038665b140d87da964757ac76e 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/types.hpp" #include "ck/host/stringutils.hpp" #include diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index bd7ef463fbe64d5bc3d07665cb4757598657f2ad..9e2d990d9bf5b4ee2434dd9b9700e02e57cde1e4 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 50290fa25ad385ce1e657de8cac3042227cc6787..9902caab0496eaa7a182ff2c7896992bc9908cdf 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index b558d97c783c7f9a2a0901013c5ea4cee053d175..205283e7aad2bd94f09d50063d67ecc1e1bd9ed3 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index e2972a93d2f7e11aa6b03dc35b7aae6663e70f93..2b83af24321dc021b35386cc2f25b4ca7da7d102 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index b728096c51e80d0d72f407160275e3699cf5a16a..fbe27e9c8b82f9b5ddf339a11bfc4d5e3cf92c80 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 99d4c6497331f65d19adf302bd47dbaa22ac4b40..24fde2e52358688f1c9ab4ba9d68cd47a0d9a76a 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index c4413b47be2b23a36dd2a631794876dae8b98776..a49714f7c6850fd83c443592539c3f6e4a0beded 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index e962d4cd3e1e1573b13272e052eeee646b05dec1..3163bb08edad50baf19566de420c9aa252ed066e 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #include #include +#include #include #include diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9f38e90416e0d2363a921df1ac4268bbc82e55ff..b1ee729f77518f2ddf312c2965454e43e589d8cb 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp index 92edf1262832d5e69c5751b162d9b5a43aac5a58..52b94d4b70ba3eecb8f2f9b20bf7d7e39e6fa2e0 100644 --- a/codegen/test/rtc/include/rtc/manage_ptr.hpp +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index a0a2cb9b77480f7c32fb531f77a8ad049024dab2..2f3b26cc43549c7e21b84366e5a2d1eb80f203b4 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 8cb71b9043cb92c675ce421d668f95a8886291c2..5a70f898e8cd0b0d97c696d0bf4b41dc290db1f6 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 747f83e3baa240159adcf2e89847f4a1bad245a8..6f16e36720954a7908e3c3ecece4732115797cb1 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 9fe38e84ad6624bcb82d8f3f97a0767ecd92108c..982e95de172fcb2b02633d6f9a6daf7758bf80ba 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include +#include #include // extern declare the function since hip/hip_ext.h header is broken diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 4e89bc35399075d67dcbc03621c1445c6eb6f66b..b36b17cce1cb50a7e14e06f4956f771373570014 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 3a2e266ef5ef48470ec10df18eb11cabec34104d..e9df8c9f5ff144152a18c09b6b003294255a7350 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.8.5 +rocm-docs-core==1.15.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index b65d2391f65da7295f84d801b0aae082c23769b1..a42fdf09bf47e7e86533774a3ea18d5ac7eb0608 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -8,6 +8,13 @@ accessible-pygments==0.0.5 # via pydata-sphinx-theme alabaster==0.7.16 # via sphinx +asttokens==3.0.0 + # via stack-data +attrs==24.3.0 + # via + # jsonschema + # jupyter-cache + # referencing babel==2.15.0 # via # pydata-sphinx-theme @@ -25,9 +32,17 @@ cffi==1.16.0 charset-normalizer==3.3.2 # via requests click==8.1.7 - # via sphinx-external-toc + # via + # jupyter-cache + # sphinx-external-toc +comm==0.2.2 + # via ipykernel cryptography==43.0.0 # via pyjwt +debugpy==1.8.12 + # via ipykernel +decorator==5.1.1 + # via ipython deprecated==1.2.14 # via pygithub docutils==0.21.2 @@ -38,20 +53,56 @@ docutils==0.21.2 # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +exceptiongroup==1.2.2 + # via ipython +executing==2.1.0 + # via stack-data fastjsonschema==2.20.0 - # via rocm-docs-core + # via + # nbformat + # rocm-docs-core gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via rocm-docs-core +greenlet==3.1.1 + # via sqlalchemy idna==3.7 # via requests imagesize==1.4.1 # via sphinx +importlib-metadata==8.6.1 + # via + # jupyter-cache + # myst-nb +ipykernel==6.29.5 + # via myst-nb +ipython==8.31.0 + # via + # ipykernel + # myst-nb +jedi==0.19.2 + # via ipython jinja2==3.1.4 # via # myst-parser # sphinx +jsonschema==4.23.0 + # via nbformat +jsonschema-specifications==2024.10.1 + # via jsonschema +jupyter-cache==1.0.1 + # via myst-nb +jupyter-client==8.6.3 + # via + # 
ipykernel + # nbclient +jupyter-core==5.7.2 + # via + # ipykernel + # jupyter-client + # nbclient + # nbformat latexcodec==3.0.0 # via pybtex markdown-it-py==3.0.0 @@ -60,16 +111,48 @@ markdown-it-py==3.0.0 # myst-parser markupsafe==2.1.5 # via jinja2 +matplotlib-inline==0.1.7 + # via + # ipykernel + # ipython mdit-py-plugins==0.4.1 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser==3.0.1 +myst-nb==1.1.2 # via rocm-docs-core +myst-parser==3.0.1 + # via myst-nb +nbclient==0.10.2 + # via + # jupyter-cache + # myst-nb +nbformat==5.10.4 + # via + # jupyter-cache + # myst-nb + # nbclient +nest-asyncio==1.6.0 + # via ipykernel packaging==24.1 # via + # ipykernel # pydata-sphinx-theme # sphinx +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +platformdirs==4.3.6 + # via jupyter-core +prompt-toolkit==3.0.50 + # via ipython +psutil==6.1.1 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data pybtex==0.24.0 # via # pybtex-docutils @@ -87,26 +170,45 @@ pygithub==2.3.0 pygments==2.18.0 # via # accessible-pygments + # ipython # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.8.0 # via pygithub pynacl==1.5.0 # via pygithub +python-dateutil==2.9.0.post0 + # via jupyter-client pyyaml==6.0.1 # via + # jupyter-cache + # myst-nb # myst-parser # pybtex # rocm-docs-core # sphinx-external-toc +pyzmq==26.2.0 + # via + # ipykernel + # jupyter-client +referencing==0.36.1 + # via + # jsonschema + # jsonschema-specifications requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.8.5 +rocm-docs-core==1.15.0 # via -r requirements.in +rpds-py==0.22.3 + # via + # jsonschema + # referencing six==1.16.0 - # via pybtex + # via + # pybtex + # python-dateutil smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 @@ -116,6 +218,7 @@ soupsieve==2.5 sphinx==7.4.7 # via # breathe + # myst-nb # myst-parser # pydata-sphinx-theme # rocm-docs-core @@ -149,15 +252,43 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # 
via sphinx +sqlalchemy==2.0.37 + # via jupyter-cache +stack-data==0.6.3 + # via ipython +tabulate==0.9.0 + # via jupyter-cache tomli==2.0.1 # via sphinx +tornado==6.4.2 + # via + # ipykernel + # jupyter-client +traitlets==5.14.3 + # via + # comm + # ipykernel + # ipython + # jupyter-client + # jupyter-core + # matplotlib-inline + # nbclient + # nbformat typing-extensions==4.12.2 # via + # ipython + # myst-nb # pydata-sphinx-theme # pygithub + # referencing + # sqlalchemy urllib3==2.2.2 # via # pygithub # requests +wcwidth==0.2.13 + # via prompt-toolkit wrapt==1.16.0 # via deprecated +zipp==3.21.0 + # via importlib-metadata diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt old mode 100644 new mode 100755 index 52c8ab5806454c787bb2e4648785aecdda245274..97ac21eba5a3722c1f3127e00cdcdb79e01b2634 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -29,10 +29,16 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) +add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) +add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp) +add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) 
@@ -42,9 +48,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) @@ -58,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -77,6 +80,9 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) +add_example_executable(example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_streamk_v3) + add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 6e1c9f2a0da41124f67210c9e65d99a6bceca536..9664c50b6e11ca846d11118de648b908d91cb3b6 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -44,7 +44,7 @@ struct ProblemSizeStreamK final ck::index_t StrideB = -1; ck::index_t StrideC = -1; - ck::index_t NumSKBlocks = -1; + ck::index_t NumSKBlocks = -1; // number of stream-k blocks }; struct ProblemSizeStreamK_universal final { @@ -76,7 +76,7 @@ struct ProblemSizeSplitK final struct ExecutionConfig final { // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU - int 
do_verification = 3; + int do_verification = 1; int init_method = 2; bool time_kernel = false; }; @@ -287,3 +287,85 @@ bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp old mode 100644 new mode 100755 diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b491173a6db282811b6733626f4d632ab2d914d --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto 
f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 
20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp deleted file mode 100644 index 108c100cbdf88c0a8e33e9d372daddf2a56894ab..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_xdl_bf16_rtn.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "common.hpp" - -#include "ck/utility/type_convert.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" - -using ADataType = ck::bhalf_t; -using BDataType = ck::bhalf_t; -using CDataType = ck::bhalf_t; -using AccDataType = float; -using CShuffleDataType = float; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, 
BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -using ReferenceComputeType = float; -using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp new file mode 100755 index 0000000000000000000000000000000000000000..5b56a43483b85a8f1b9da07ea394feadd2c682b8 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, + 64, 8, 8, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 2, S<1, 32, 
1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 07d51855d6b60bfa32b4e815627dfbe64a0006d8..414683ffdf63893f95629d52f78a7a95a733b9c4 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; -// // clang-format on -// clang-format off using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| 
Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp index 2e27fc66f9456b67ff860e5926f6a24149ab14c3..b0e36b394bb217ea43998923c4f68a75fe413e98 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" -using ADataType = ck::f8_t; -using BDataType = ck::half_t; +using ADataType = ck::half_t; +using BDataType = ck::f8_t; using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; @@ -29,15 +29,15 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, - 64, 16, 8, + 256, 8, 16, 16, 16, 1, 1, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 16, 16, 0, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, 
PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + 
b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8a40baa8ad3b30ef3f7808accc0d28b47684fd8 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 32, + 32, 32, + 4, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) 
{ + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 
01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); 
+ + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp index 5b163962b95132e2749151f88c682968417fd361..36ac51f1da59333209e677a135119475156b757a 100644 --- a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -8,7 +8,7 @@ using ADataType = ck::half_t; using BDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = ck::half_t; +using CShuffleDataType = float; using CDataType = ck::half_t; using ALayout = Row; @@ -43,6 +43,17 @@ using DeviceGemmV2_Streamk_Instance = using ReferenceGemmInstance = ck::tensor_operation::host:: 
ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example_streamk_v2.inc" int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp index ad370f570efd98e90c2bd53fe7522e5ee249586a..4a969246cd80d3aa6bd27cdfb556ca37368fe091 100644 --- a/example/01_gemm/gemm_xdl_fp16_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -27,17 +27,17 @@ using DeviceGemmV2Instance = ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, - 256, - 224, 256, - 64, 8, 2, + 64, + 16, 16, + 256, 8, 8, 16, 16, - 7, 8, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 8, 2, 0, - 1, 2, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp new file mode 100755 index 0000000000000000000000000000000000000000..3b79ae9b858b49f76ee4eaf039a4cfd31c503d20 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 256, + 128, 16, 16, + 16, 16, + 4, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp old mode 100644 new mode 100755 index 5a02457dafd1e021b2c0fa71bd1498c891135304..dbdf7199e857969f4cc3b8af5ae7ea56e696bd97 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -15,7 +15,6 @@ using F16 = ck::half_t; using ALayout = Row; using BLayout = Row; -// using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 
bafec3f358037cc9c4d3758e8395556e5eeb4ce8..4371af6244cb6c8d1d2288f174d4fcb917618743 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -5,88 +5,6 @@ #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -143,8 +61,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) switch(config.init_method) { case 0: - ck::utils::FillConstant{static_cast(1.f)}(a_m_k); - ck::utils::FillConstant{static_cast(1.f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.f)}(b_k_n); break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc 
b/example/01_gemm/run_gemm_example_streamk_v2.inc index 8ed8b81bec13c8d1432077589c83d71866387e3c..9ee380d247c2b368493fa03ba627d55c0dcb39c6 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -176,6 +94,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout 
<< "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -196,6 +115,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -240,6 +161,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); + } + bool pass = true; if((config.do_verification == 1) || (config.do_verification == 3)) { @@ -271,6 +199,36 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #endif } + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." 
<< std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + if(config.time_kernel) { ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 71524fdecf1306938bc381e84ac873fc11142ede..2b60fa5d2867055f841a8bb749d0ab1a910da5f1 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are 
acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -261,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) if(config.time_kernel) { ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4}); + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index be47665a262ec6619816c249bdfdb96ba3c8ae16..aa9367cdcfc83420fa3015ade75c3090df5bd9de 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp index ecff7b4713f1ed5ff6402e9ab8ae7198f900e1fe..117a18e3bd9f1d65567c9c881da6519b70d8eca8 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); for(int j = 0; j < NumDMatrices; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } break; default: - 
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); for(int j = 0; j < NumDMatrices; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); } } } @@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co // do GEMM auto argument = gemm.MakeArgument( p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - gemm.SetKBatchSize(argument, config.k_batch); + gemm.SetKBatchSize(&argument, config.k_batch); if(!gemm.IsSupportedArgument(argument)) { throw std::runtime_error( @@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); invoker.Run(argument, StreamConfig{nullptr, false, 1}); diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 965a0e7e37836c06e3aceb29cd19976444c99707..db162fe44440296a203044515fbcc1b3d48ddd79 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { auto group_count = problem_size.group_count; - using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; using GemmDesc 
= ck::tensor_operation::device::GemmDesc; // GEMM shape @@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); for(int j = 0; j < NumDs; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); for(int j = 0; j < NumDs; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); } } } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index a193fc39ba637cbd41df4743e0157ca40c38407b..5bdc9931926d88e8b3fb633c490c70adf10bc5c7 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } - d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; @@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.GetDeviceKernelArgSize(&argument), hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 1a2bcfb33edd5d87d692a5aedd99ae1b0edfac2a..6806bd1886d666f3de2009af07059537bb84b68c 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 0a63a29843aa319a38b3673db702952e06d8d851..8418c10f5ebaaec66c61f095f408607957d172f1 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 320870e0de7cab9dee43cab6542f13ca5341a90b..64125cd1d0183a6764afbdb89ca2e7581a4e8f8b 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once struct ProblemSize final @@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -168,9 +171,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); - DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); + + DeviceMem gemm_workspace, gemm_kargs; - gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + // The following is necessary since TwoStage kernel is using additional memory both + // for Workspace and kernel arguments. 
+ if(kargs_size > 0) + { + gemm_kargs.Realloc(kargs_size); + gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer()); + } + if(workspace_size > 0 && workspace_size != kargs_size) + { + gemm_workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); + } if(!gemm.IsSupportedArgument(argument)) { diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc03664043b1a0295f9f1dcfb38a77002..018b57f82c5d23a53eb4aabda785065e8041686d 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index 90d80f9f034b391f75c498f3a34232edf64f5260..277fea0272a05cd1728ac536ecfecccd358a9d3d 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -175,8 +175,8 @@ int main(int argc, char* argv[]) b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); } c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 720af39af645e622cba1897a46fb1f7004516dae..d5157209449ba15cb9956bd9c04c78ef36b9fc27 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -22,3 +22,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() + +add_example_executable(example_batched_gemm_xdl_fp16int4_b_scale_v3 batched_gemm_xdl_fp16int4_b_scale_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16int4_b_scale_v3) diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index fa8b752185c112160434d2e1b9866312dbbbda14..548500518fe90adbbb107ff4305db14269940622 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM + 0, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // 
BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN + 0, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..42171bcdb7f16d8368c1cdd259825db91bfed7eb --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -0,0 +1,82 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = F32; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr 
auto PermuteA = false; +static constexpr bool PermuteB = false; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 256; + +// clang-format off +using DeviceBatchedGemmV2Instance = + ck::tensor_operation::device::DeviceBatchedGemm_Xdl_CShuffleV3_BScale< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 1, + S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm; +#include "run_batched_gemm_example_fp16int4_b_scale.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_fp16_int4_b_scale_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc new file mode 100644 index 0000000000000000000000000000000000000000..8c4913dbccd996d322e6a7eda7d736cfcbefe281 --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -0,0 +1,578 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include + +#pragma once +struct ProblemSize final +{ + ck::index_t M = 128; + ck::index_t N = 128; + ck::index_t K = 384; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + // Batched Gemm count + ck::index_t batch_count = 2; + + // Split K count + ck::index_t KBatch = 1; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, + N, + K, + stride_A, + stride_B, 
+ stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + KBatch] = problem_size; + + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t batch_BScale_Stride = + ((K + Scale_Block_K - 1) / Scale_Block_K) * ((N + Scale_Block_N - 1) / Scale_Block_N); + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b_g_k_n_permute( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b1_g_k_n( + f_host_tensor_descriptor(batch_count, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + batch_BScale_Stride, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + 
break; + case 4: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + printf("a_g_m_k size: %zu, b_g_k_n size: %zu, b1_g_k_n size: %zu, c_g_m_n size: %zu\n", + a_g_m_k.mDesc.GetElementSpaceSize(), + b_g_k_n_permute.mDesc.GetElementSpaceSize(), + b1_g_k_n.mDesc.GetElementSpaceSize(), + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + printf("Permute B\n"); + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int bs = 0; bs < batch_count; bs++) + { + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; 
jj++) + { + b_g_k_n_permute(bs * batch_stride_B + j * N * K1 + i * K1 + jj) = + b_g_k_n(bs * batch_stride_B + i * K + (j * K1 + jj)); + } + } + } + } + } + else + { + b_g_k_n_permute = b_g_k_n; + } + + // vector pk_i4x4 permute + for(int bs = 0; bs < batch_count; bs++) + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 6, i) = i4x2; + } + } + } + } + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b_g_k_n_device_buf.ToDevice(b_g_k_n_permute.mData.data()); + b1_g_scale_device_buf.ToDevice(b1_g_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceBatchedGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + stride_A, + stride_B, + stride_C, + Scale_Stride_BN, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_BScale_Stride, + static_cast(b1_g_scale_device_buf.GetDeviceBuffer()), + batch_count, // batch count + KBatch, // split K count + a_element_op, + b_element_op, + c_element_op); + 
+ if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + Tensor b_g_k_n_dequant({batch_count, K, N}); + if(config.do_verification) + { + float v_b = 0; + for(int bs = 0; bs < batch_count; bs++) + { + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_g_k_n_dequant(bs, k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_g_k_n(bs, k / Scale_Block_K, n / Scale_Block_N)); + } + } + } + + auto ref_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_g_m_k, + b_g_k_n_dequant, + c_g_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + hip_check_error(hipDeviceSynchronize()); + + c_g_m_n_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + +#if 0 + // print A matrix + printf("A matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&a_g_m_k(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < K; j++) + { + printf("%.2f,", static_cast(a_g_m_k(bs, i, j))); + } + printf("\n"); + } + } + + // print B matrix original + printf("B matrix original:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n(bs, 0, 0))); + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + printf("%d,", static_cast(i4)); + } + printf("\n"); + } + } + + // print B matrix + printf("B matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n_dequant(bs, 0, 0))); + for(int i = 0; i < K; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(b_g_k_n_dequant(bs, i, j))); + } + printf("\n"); + } + } + + // print B scale matrix + printf("B Scale matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b1_g_k_n(bs, 0, 0))); + for(int i = 0; i < (K + Scale_Block_K - 1) / Scale_Block_K; i++) + { + for(int j = 0; j < (N + Scale_Block_N - 1) / Scale_Block_N; j++) + { + printf("%.2f, ", static_cast(b1_g_k_n(bs, i, j))); + } + printf("\n"); + } + } + + // print C matrix + printf("C matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf( + "batch %d -> Address: %p\n", bs, 
static_cast(&c_g_m_n_device_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_device_result(bs, i, j))); + } + printf("\n"); + } + } + + printf("C reference matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&c_g_m_n_host_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_host_result(bs, i, j))); + } + printf("\n"); + } + } +#endif + + return pass; +} + +bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 128 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 256 * (dis(gen) + 2); + + problem_size.batch_count = 2; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 7) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + if(argc >= 8) + { + problem_size.batch_count = std::stoi(argv[7]); + } + + if(argc >= 9) + { + problem_size.KBatch = std::stoi(argv[8]); + } + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + 
problem_size.batch_stride_C = problem_size.M * problem_size.N; + + return run_batched_gemm(problem_size, config); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index e3370b880bb29e78a88e5ed4ca1fcdcd14afa2f8..ce42a20be78940a9d6203ac9e8888d4b0bfe4910 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +#if defined(CK_USE_AMD_MFMA_GFX950) +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // 
BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; +#else // defined(CK_USE_AMD_MFMA_GFX950) template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -80,6 +130,7 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; +#endif // defined(CK_USE_AMD_MFMA_GFX950) template using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 27602e2313f7aa197e88e1fabeb39245e2fdf5eb..1514fc48b3cbc66d8c141114a169f893f74b7f9c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -157,7 +157,7 @@ int run(int argc, char* argv[]) break; default: a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index fa76faea84e4551ddf8d0617c132dbe7a6045fb3..2b02069e659056d6e1f9f1b4f8403b50b62df33f 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -118,7 +118,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index 2e77479bccad137b00728b8280d8fbc6ea0f37f0..e0ccb6dad15dbf04c881afd717a2cde372b8e24a 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -153,7 +153,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 9ff4c56e0695aaf02fc88661722d7f410a928784..0ad031cc7127ec9ee3761681413191404b157b26 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -178,7 +178,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index ea1e2734a684b61a363eb93ef0e2ff933f900ea5..cdfd86dff44cb78261221188663497a2392b889f 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -152,7 +152,7 @@ int run(int argc, char* argv[]) break; default: a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); - b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 609d085299e62b5175c7af471b1ce87d75d4217d..7ac29f33ca61439a135d038d84e3f1061833d46b 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -156,7 +156,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index b05915c07fb681ef2e1693da9ab68e4b925e2731..fb9b1b0bd7283eb5124ab1e47c2c828d264a7063 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
int run(int argc, char* argv[]) { @@ -156,7 +156,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 3fdaaebb0f5ff90c8399c65473773596da586ffb..2cb69380e50c80831c3ddb6287d8c96fa0bd9822 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -173,7 +173,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e3690984abc08cff2e5a450f9f6b6ce13217b883..cb1d3410c986c770b16bfb20eaace4bebd6ee289 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once struct ProblemSize final @@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index ff1282f3c78d9e227cc8a2132b1cd81fa7c27187..f27dc60541acd5cde7614734c1fad384aec9e280 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -377,7 +377,7 @@ int main(int argc, char* argv[]) break; default: a0_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 8a0474156ccb93bc01323d52d87fc3099e9ab332..6af8ac6488f93e9b951419da24f149557de362bf 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -41,7 +41,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParams \ diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 8ab56b21a638c518d6d2584a6ffb3b94a6130470..c5c5a84b67affa5ff0762b5f4da0be2ddaf75842 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index a90a6340a431a55c271dca3d3d0d1771382218f5..392cb155cb8f5632ca6a1fa16179edfb58a17ee2 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -248,7 +248,7 @@ int main(int argc, char* argv[]) d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 742fd5547a97698f363816b6833e1621bce92d45..055d253042ab53e0c7c9bd7a2a6c8ed95c758a5f 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 809c1a956cce8dcfd28623d56e3be86075acfd28..1ba8133ea71b459265f44d955843149e0fb5c042 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt index 9d90cdd244e75d1695b35c0bd367e457816a200b..7c09177049ac2c2f667b3fd80afc1cea882883f4 100644 
--- a/example/62_convnd_activ/binary/CMakeLists.txt +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt index 07f42075bd0a8a211f1c72cd7e60c6673469f846..6eb7fb8ece1dd03812d1e842b7ad955c2a72ebe2 100644 --- a/example/62_convnd_activ/convinvscale/CMakeLists.txt +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 9264da24a69896f309796154084098d61d5e43ac..a52818e21e7fbe6e190cc5a62b7278965bf65010 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp index 978221f8e1093485cec0e98b140d801c1ca01335..bf560f8a4347a7eb505880bc1aa2f28bce7e923e 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp @@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification, { case 0: break; case 1: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - 
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5 + in.GenerateTensorValue(GeneratorTensor_2{-5, 6}); + wei.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt index 40cfd74aa4570b5fd3bad0f8ffb96824b97996e4..f8bc13c8f78b38d6f202c7f9e7bbd6cdea6eddd3 100644 --- a/example/62_convnd_activ/convscale_add/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt index ff9020a70706ef1eb1d142211ff2052a486db266..a794d68bb6ad8ea060ade834a1b194d30ee9dd7f 100644 --- a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt index 95589cedcb406d55b9e91de6d7cdb1a06a87a01d..a348e30a971fa8249a016e2076d6bd364858702c 100644 --- a/example/62_convnd_activ/convscale_relu/CMakeLists.txt +++ 
b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt index 23f07439a56f21090204248036f95f31de73d6a0..21613b1ab371851ef8b4283328d4aa827789f2d5 100644 --- a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt index c89c82d384c68a9676025421d04d542e28725710..1c865d4c9582bfca93a38d997e74d4d844319cb8 100644 --- a/example/62_convnd_activ/multi_AB/CMakeLists.txt +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt index 3470e9b9456f7361e1a05c07e99b00fe90990a31..927b2e334164309fde8c14ac5bc76f76a6ee5e4d 100644 --- a/example/62_convnd_activ/unary/CMakeLists.txt +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git 
a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 2568754648d6a04a3066921c0f2ea2e3ee967962..9b7849a6543e596ed8d0b43adf2f88825472b4ba 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -205,7 +205,6 @@ int main(int argc, char* argv[]) a1_device_buf.ToDevice(a1_m_k.mData.data()); b0_device_buf.ToDevice(b0_k_n.mData.data()); b1_device_buf.ToDevice(b1_k_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -253,8 +252,6 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) { Tensor c_m_n({M, N}); diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..93770684df2fbc35480df8f2c838c9aca5db4155 --- /dev/null +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -0,0 +1,5 @@ +add_custom_target(example_gemm_mx) + +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) + diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c0a0972db6a9ca6e72b20c49460fda9e77911167 --- /dev/null +++ b/example/67_gemm_microscaling/README.md @@ -0,0 +1,17 @@ +# GEMM Examples for Microscaling Formats + +## example_gemm_mx_fp8 + +```bash +# arg1: verification (0=no, 1=CPU) +# arg2: initialization (0=no init, 1=integer value, 2=decimal value) +# arg3: time kernel (0=no, 1=yes) +# arg4: verbosity (0=no info, 
1=verbose info) +# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC +./bin/example_gemm_mx_fp8 1 1 0 1 +``` + +```bash +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 +``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5b00b5a123b4c976035e1ea6ada728b1f7f71950 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -0,0 +1,415 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +using ScaleDataType = ck::e8m0_bexp_t; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct ExecutionConfig final +{ + int do_verification = 1; // (0=no, 1=CPU) + int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + bool time_kernel = false; // (0=no, 1=yes) + int verbosity = 0; // (0=no info, 1=verbose info) +}; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + 
ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; +}; + +bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.StrideA = std::stoi(argv[8]); + problem_size.StrideB = std::stoi(argv[9]); + problem_size.StrideC = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl + << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} + +template +bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using ELayout = CLayout; + using DsLayout = ck::Tuple<>; + using DsDataType = ck::Tuple<>; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = CElementWiseOp; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +#if 1 + // XXX: These parameters should not exist in MX-native GEMM kernel + static constexpr ck::index_t Scale_Block_M = 128; + static constexpr 
ck::index_t Scale_Block_N = 128; +#endif + static constexpr ck::index_t Scale_Block_K = MXVectorSize; + + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA + // instructions. + // + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized + // scaled type convert functions. + // + // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to + // ScaleBlockK (aka MXVectorSize). + // Additionally, the following is also expected: + // static_assert(ScaleBlockM % MPerBlock == 0); + // static_assert(ScaleBlockN % NPerBlock == 0); + // In MX-native GEMM kernel these requirements should be relaxed. + // + // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 + // MfmaSelector::selected_mfma.k_per_blk = + // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 + // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float + + // clang-format off + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| + // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| 
DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | + // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; + // clang-format on + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1}); + } + else + { + return HostTensorDescriptor({row, col}, {1, stride}); + } + }; + + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = 
f_get_default_stride(M, N, StrideC, CLayout{}); + + if(K % Scale_Block_K != 0) + { + throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)"); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale( + f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification + Tensor c_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host + + if(config.verbosity >= 0) + { + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; + } + + switch(config.init_method) + { + case 0: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." 
<< std::endl; + } + break; + case 1: + case 2: + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + if(config.verbosity > 0) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {0.5}" << std::endl; + std::cout << "Init B = {1}" << std::endl; + std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + default: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + } + + if(config.verbosity > 0) + std::cout << "Device memory allocation..." << std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(config.verbosity > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + if(config.verbosity > 0) + std::cout << "Done." 
<< std::endl; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideC, + a_scale_device_buf.GetDeviceBuffer(), + b_scale_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong!\n" + "Provided combination of compilation and runtime parameters is " + "not consistent with the supported device_gemm arguments."); + } + + if(config.verbosity > 0) + std::cout << "Computing GEMM on device..." << std::endl; + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + + bool res_verified = true; + if(config.do_verification > 0) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Computing GEMM on host..." << std::endl; + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Comparing results..." 
<< std::endl; + } + + if(config.init_method == 1) + { + res_verified = + res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; + std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + } + + res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!"); + + if(config.verbosity > 0 && res_verified) + std::cout << "Done." << std::endl; + } + else + { + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + } + + if(config.time_kernel) + { + std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + return res_verified; +} + +template +bool run_mx_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_mx_gemm(problem_size, config); +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e21698ec41d5b4432e0a38b6d27ea5f223751e --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +#if 1 +// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type +using XDataType = float; +#else +using XDataType = ck::e8m0_bexp_t; +#endif +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 128; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index ea739c707153dd43c5a0c130e8fa916973e83b90..bcb62df62570bded896594e65ff4cc081b7ed12e 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,14 @@ include_directories(BEFORE add_custom_target(examples) + +# list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds) +# all other tests are labelled as SMOKE_EXAMPLE +set(REGRESSION_EXAMPLES + example_sparse_embedding3_forward_layernorm +) + + function(add_example_dependencies EXAMPLE_NAME FILE_NAME) if(FILE_NAME) add_dependencies(EXAMPLE_NAME FILE_NAME) @@ -15,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST 
DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example source file ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -54,6 +62,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any DPP examples if DPP_KERNELS not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + message("removing dpp example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") @@ -68,6 +83,13 @@ 
function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any microscaling examples if gfx950 target is not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing microscaling example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #Do not build any FP8 examples if CK_ENABLE_FP8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") @@ -87,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) + elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -100,6 +124,15 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 0) endif() #message("add_example returns ${result}") + if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + #message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}") + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") + add_dependencies(smoke 
${EXAMPLE_NAME}) + elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + #message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}") + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") + add_dependencies(regression ${EXAMPLE_NAME}) + endif() set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) @@ -171,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -181,8 +214,10 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() + #message("add_example returns ${result}") set(result ${result} PARENT_SCOPE) + endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000000000000000000000000000000000000..43b3419f80183d6a091fbafc339f8e045fcfe991 --- /dev/null +++ b/example/README.md @@ -0,0 +1,2 @@ +[Back to the main page](../README.md) +# Composable Kernel examples \ No newline at end of file diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 1ba76a523eed58b60aa67b2a58b9af0825647f06..9ba3a453fc1ed0dc76ea8dd93b47de5002630019 100644 --- 
a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -102,6 +102,11 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally specify the use of OCP_FP8 +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index c7ab296c3bbedd5c1358e45cf28fb29b8786c6f5..e9806e7a67ad3782444baf5b7c6a41c8a639ab0f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ## kernel The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. -There are 3 template parameters for this kernel template. -* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). 
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 66691356ab7d882ff6d073c0e38c72cc416d6320..332707eafd126ca26bcdb7b47ee35d238978a880 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -2,10 +2,17 @@ # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation -DTYPE_MAP = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp8" : "ck_tile::fp8_t" +FWD_DTYPE_MAP = { + "fp16" : "FmhaFwdFp16", + "bf16" : "FmhaFwdBf16", + "fp8" : "FmhaFwdFp8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16" +} + +BWD_DTYPE_MAP = { + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16" } MASK_IMPL = { @@ -112,6 +119,7 @@ PIPELINE_MAP = { PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 096394c0c95e8833051a6c5de06f49528eeb8ebf..83a1e82d6d23e57b522b59033609fa2dd6a84da9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -283,7 +283,7 @@ class FmhaBwdApiPool: inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], 
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_deterministic=BOOL_MAP[trait.deterministic]) @@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel: FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, @@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaBwdApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue @@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel: FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_spad = BOOL_MAP[self.F_spad], F_dvpad = BOOL_MAP[self.F_dvpad], F_mode = MODE_MAP[self.F_mode], @@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue @@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel: FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], 
F_bm0 = self.F_bm0, F_bn0 = self.F_bn0, F_spad = BOOL_MAP[self.F_spad], @@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e5ee1d22e74155a91df630494da6c1035f7fdf38..1c9d743f3da58d992ad96cc3a2270c7edfaf11aa 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = { 256: 256 } -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", -} - FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py @@ -44,13 +39,12 @@ FMHA_FWD_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - fmha_warp_tile_{F_idx}, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - fmha_warp_tile_{F_idx}, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, @@ -91,9 +85,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; + ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, 
{F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; @@ -282,7 +274,7 @@ class FmhaFwdApiPool: F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -301,20 +293,24 @@ class FmhaFwdTileSize: F_bk1 : int # tile size along kv gemm unroll F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen F_rk0 : int # number of warps for gemm0 along head dim q (not used) F_rm1 : int # number of warps for gemm1 along q seqlen F_rn1 : int # number of warps for gemm1 along head dim v F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ 
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass class FmhaFwdKernel: @@ -326,12 +322,6 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - @property def template(self) -> str: kernel_body = str() @@ -339,7 +329,7 @@ class FmhaFwdKernel: FMHA_FWD_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, @@ -352,9 +342,12 @@ class FmhaFwdKernel: F_rm1 = self.F_tile.F_rm1, F_rn1 = self.F_tile.F_rn1, F_rk1 = self.F_tile.F_rk1, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], @@ -368,13 +361,12 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ self.F_tile.name + 
'_' + self.F_pipeline.name @property @@ -409,17 +401,17 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -462,6 +454,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): 
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines @@ -469,7 +464,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_tile_dict_from_dtype(dtype) if d == None: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index cfd1d01c91a57e824a8ef7370af56190ccede4f5..2f20819302f85eb2ceec358177861e3782daac21 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< fmha_pipeline_problem_{F_idx}>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdAppendKVKernel, - fmha_pipeline_{F_idx}>; +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; @@ -181,7 +179,7 @@ class FmhaFwdAppendKVApiPool: inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, 
F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -216,7 +214,7 @@ class FmhaFwdAppendKVKernel: FMHA_FWD_APPENDKV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bs = self.F_tile.F_bs, F_bsk = self.F_tile.F_bsk, F_bd = self.F_tile.F_bd, @@ -301,6 +299,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> elif dtype in ['fp8', 'bf8']: # rope/paged-kv is not supported pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines @@ -308,7 +309,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaFwdAppendKVApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) if d == None: continue @@ -352,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b084e9d0fcde38d8786f19f769e731e5c0c16786..37745dd38299e30eb78a756e714f710102788f48 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ 
b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = { FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", } @@ -47,16 +48,15 @@ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_mask_{F_idx} = {F_mask}; namespace {{ -template -struct kernel_runner {{ +template +struct instance {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape = ck_tile::TileFmhaShape, - fmha_warp_tile, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - fmha_warp_tile, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, @@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_dpad}, {F_dvpad}, {F_bias}, - false, + /*kHasBiasGrad=*/false, {F_lse}, {F_squant}, {F_pagedkv}, kHasUnevenSplits, + kMergeNumHeadGroupsSeqLenQ, {F_occupancy}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< @@ -96,9 +97,7 @@ using fmha_epilogue = {F_spad}, {F_dvpad}>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -112,33 +111,55 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, 
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wtautological-compare" + +namespace {{ +template +void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} +}} +}} // anonymous namespace + +#pragma clang diagnostic pop + template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // we don't check every seqlen_k values for kvcache if (a.seqlen_k_ptr != nullptr) {{ - kernel_runner::run(s, a); + run_instance(s, a); // make sure F_bn0 is divisible by F_bk1 }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - kernel_runner::run(s, a); + run_instance(s, a); }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} template<> std::string fmha_fwd_splitkv_get_name_() {{ - using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -148,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ template -struct kernel_runner {{ +struct instance {{ using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, {F_dvpad}, {F_lse}, @@ -161,9 +182,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, {F_hdim}, - {F_bm0}, - {F_bn1}, {F_mode}, + {F_bn1}, fmha_trait>; using fmha_pipeline = 
ck_tile::BlockFmhaFwdSplitKVCombinePipeline< @@ -177,9 +197,7 @@ using fmha_epilogue = false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVCombineKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -192,7 +210,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }}; }} -using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; #include @@ -201,22 +219,22 @@ template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if (a.num_splits <= 8) {{ - kernel_runner<3>::run(s, a); + instance<3>::run(s, a); }} else if (a.num_splits <= 16) {{ - kernel_runner<4>::run(s, a); + instance<4>::run(s, a); }} else if (a.num_splits <= 32) {{ - kernel_runner<5>::run(s, a); + instance<5>::run(s, a); }} else if (a.num_splits <= 64) {{ - kernel_runner<6>::run(s, a); + instance<6>::run(s, a); }} else if (a.num_splits <= 128) {{ - kernel_runner<7>::run(s, a); + instance<7>::run(s, a); }} }} template<> std::string fmha_fwd_splitkv_combine_get_name_() {{ - using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -231,11 +249,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a if(s.log_level_ > 0) std::cout << ", " << fmha_fwd_splitkv_get_name_() - << ", " << fmha_fwd_splitkv_combine_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, [=](const 
ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} ); }} @@ -247,12 +265,31 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; - - return fmha_fwd_splitkv_(s, a); + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + + // get combine kernel tile sizes + using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % 32 == 0); + + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 
/*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} }} """ @@ -292,7 +329,7 @@ class FmhaFwdSplitKVApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -303,7 +340,7 @@ class FmhaFwdSplitKVApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -314,7 +351,7 @@ class FmhaFwdSplitKVApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -326,7 +363,7 @@ class FmhaFwdSplitKVApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -421,11 +458,11 @@ class FmhaFwdSplitKVApiPool: inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -437,12 +474,11 @@ class FmhaFwdSplitKVApiPool: @dataclass class FmhaFwdSplitKVCombineTileSize: - F_bm0 : int # tile size along q seqlen F_bn1 : int # tile size along v head_dim F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other 
value will overwrite occupancy @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn1}" +\ + return f"b{self.F_bn1}" +\ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass @@ -462,7 +498,7 @@ class FmhaFwdSplitKVKernel: FMHA_FWD_SPLITKV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, @@ -475,14 +511,17 @@ class FmhaFwdSplitKVKernel: F_rm1 = self.F_tile.F_rm1, F_rn1 = self.F_tile.F_rn1, F_rk1 = self.F_tile.F_rk1, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -542,8 +581,7 @@ class FmhaFwdSplitKVCombineKernel: FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bn1 = self.F_tile.F_bn1, F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], @@ -567,17 +605,17 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - ## 
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -585,17 +623,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), - ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), + '32' : FmhaFwdSplitKVCombineTileSize(32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '96' : 
FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } else: return None @@ -614,27 +652,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - # TODO: use async pipeline when compiler is more stable + for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, 
mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse/paged-kv kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines @@ -642,7 +682,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaFwdSplitKVApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_tile_dict_from_dtype(dtype) if d == None: continue @@ -655,9 +695,6 @@ def 
get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if pipeline.F_pagedkv == 't': - # we only use batch mode kernels to handle (paged-) kvcache problems - continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -705,7 +742,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) if d == None: continue diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 2d76627a725f86122b8a734f14e0a2cb0585c616..eaf99529f3ec90c2b835a0fc2146c54be62003e5 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[]) } // different threshold for different dtype -template +template auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) { double rtol = 1e-2; @@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) } template <> -auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) +auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { double rtol = 1e-2; double atol = 1e-2; @@ -122,7 +122,7 @@ auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_ return ck_tile::make_tuple(rtol, atol); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); @@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); - using TypeConfig = FmhaBwdTypeConfig; + using TypeConfig = FmhaBwdTypeConfig; using 
QDataType = typename TypeConfig::QDataType; using KDataType = typename TypeConfig::KDataType; @@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } // clang-format on - auto [rtol, atol] = get_elimit(hdim_q, hdim_v); + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); bool dq_cur_pass = ck_tile::check_err(dq_host_result, dq_host_ref, std::string("Error: QGrad Incorrect results!"), @@ -986,11 +986,11 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 3b21a3257f850442d52deea806863e058c5d2849..6204cbcfa8ed67de0fceaf447ca32393a4282500 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -14,11 +14,19 @@ #include #include +struct FmhaBwdFp16 +{ +}; + +struct FmhaBwdBf16 +{ +}; + template struct FmhaBwdTypeConfig; template <> -struct FmhaBwdTypeConfig +struct FmhaBwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; @@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig }; template <> -struct FmhaBwdTypeConfig +struct FmhaBwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; @@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - 
args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / 
args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_do, - args.batch_stride_lsed, - args.batch_stride_dq_acc, - args.batch_stride_dk, - args.batch_stride_dv, - args.batch_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + 
args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 14291715fb0b314d06f0b27699157753a316f452..b3855e59dfc1ff149f9c365da21e64b18c07424c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -3,6 +3,7 @@ #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ref/naive_attention.hpp" #include "mask.hpp" #include "rotary.hpp" #include "utils.hpp" @@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do CPU validation or not") + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)") .insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") @@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[]) "-1 to choose s_knew in [1, s] randomly.") .insert("s_kpad", "-1", - "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "seqlen_k stride between 2 batches, currently used in group-mode only\n" "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" "along seqlen, instead of packed. 
same as xformer kv_padding") .insert("d", "128", "head dim for q, k") @@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[]) } // different threshold for different dtype -template +template auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; @@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string /*init_method*/) +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-2; double atol = 1e-2; @@ -159,7 +160,7 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string init_method) { if(init_method == "ui" || init_method == "ni") { @@ -261,7 +262,7 @@ int override_num_splits_if_necessary( return num_splits; } -template +template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); @@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser) #if !CK_TILE_FMHA_FWD_APPENDKV_API if(seqlen_knew != 0) { - std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; + std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" + << std::endl; seqlen_knew = 0; } #endif @@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); - if constexpr(!(std::is_same_v || - std::is_same_v)) + if constexpr(!(std::is_same_v || + std::is_same_v)) { if(0 < rotary_dim) { @@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser) rotary_dim = 0; } #endif + // to use fmha_fwd_appendkv(), make sure it's in batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + if(need_append_kvcache && mode == mode_enum::group) + { + std::cerr << "fmha_fwd_appendkv() will be invoked. 
ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; @@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::endl; use_cache_batch_idx = false; } -#endif - if(0 < page_block_size && use_cache_batch_idx) +#else + if(use_cache_batch_idx) { - std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " - "'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; + if(0 < page_block_size) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + else if(mode == mode_enum::group) + { + std::cerr << "group mode will not use cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } } - // the input tensor layout for kvcache is same as batch mode - const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); +#endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); - if(use_kvcache && mode != mode_enum::batch) - { - std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; - mode = mode_enum::batch; - } auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, @@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser) arg_parser.get_str("s_k"), arg_parser.get_str("s_kpad"), /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, - use_kvcache); + need_append_kvcache); // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser) return atoi(squant_str.c_str()) != 0 ? 
true : false; }(); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); - - float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max); - scale_p = dtype_max / range_p; - // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)] - scale_o = range_p * range_v / range_o / dtype_max; - } - std::string vlayout = arg_parser.get_str("vlayout"); bool lse = arg_parser.get_bool("lse"); @@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool s_randval = false; - if(p_drop > 0.0f && do_validation) + if(p_drop > 0.0f && do_validation != 0) { s_randval = true; } @@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - using TypeConfig = FmhaFwdTypeConfig; + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; using KDataType = typename TypeConfig::KDataType; @@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser) using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p 
and v is the same type + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float scale_p = 1.f; + float scale_o = 1.f; + + if(squant) + { + scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); + scale_p = p_dtype_max / range_p; + scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); + } + // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser) else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") // suitable for fp8 quantization { - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host); + ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, seed}(q_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(knew_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(vnew_host); // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); + float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } @@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * 
sizeof(int32_t)); - ck_tile::DeviceMem seqlen_k_buf( - use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); ck_tile::DeviceMem cache_seqlen_k_buf( need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); @@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser) seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data()); - seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + ? seqlen_ks.data() + : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); @@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = - (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? 
block_table_buf.GetDeviceBuffer() : nullptr); args.batch_stride_block_table = batch_stride_block_table; args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); @@ -1100,25 +1122,76 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; - if(!do_validation) + if(do_validation == 0) { std::cout << std::flush << std::endl; return true; } + if(do_validation == 2) + { + // NOTE: use gpu to do validation + ck_tile::naive_attention_fwd_traits naive_t; + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? 
+ naive_t.quant_algo = 0; + + ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::naive_attention_fwd_args naive_a; + naive_a.q_ptr = q_buf.GetDeviceBuffer(); + naive_a.k_ptr = k_buf.GetDeviceBuffer(); + naive_a.v_ptr = v_buf.GetDeviceBuffer(); + naive_a.o_ptr = o_naive_buf.GetDeviceBuffer(); + naive_a.scale_s = scale_s; + naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer + naive_a.page_table_ptr = + nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn) + naive_a.hdim = hdim_q; + naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different + naive_a.batch_q = batch; + naive_a.batch_kv = batch; + naive_a.batch_ratio_kv = 1; // batch_q / batch_kv + naive_a.seqlen_q = seqlen_qs[0]; + naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field + naive_a.nhead_q = nhead; + naive_a.nhead_kv = nhead_k; + naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv + naive_a.page_size = 0; // if paged, the seqlen-kv for each block + + ck_tile::stream_config naive_s{}; + + naive_attention_fwd(naive_t, naive_a, naive_s); + + auto o_naive_ref = o_naive_buf.ToHost(); + o_buf.FromDevice(o_host.data()); // TODO: ugly + + auto [rtol_, atol_] = get_elimit(init_method); + bool pass_ = ck_tile::check_err( + o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); + std::cout << ", valid:" << (pass_ ? 
"y" : "n") << std::flush << std::endl; + return pass_; + } o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return ck_tile::scales{scale_p}; else return ck_tile::identity{}; }(); auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); else @@ -1168,7 +1241,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - auto [rotary_cos_slice, rotary_sin_slice] = + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); ck_tile::reference_batched_rotary_position_embedding( @@ -1184,13 +1257,13 @@ bool run(const ck_tile::ArgParser& arg_parser) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); }); - } else { + } else { k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); } } else -#endif +#endif { if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); @@ -1211,7 +1284,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); - auto [rotary_cos_slice, rotary_sin_slice] = + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); ck_tile::reference_batched_rotary_position_embedding( @@ -1233,19 +1306,19 @@ bool run(const ck_tile::ArgParser& arg_parser) 
if(0 < page_block_size) { if(is_v_rowmajor) { if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); }); } else { - v_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); }); } } - else + else { - if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); }); } else { @@ -1440,7 +1513,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); + auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); pass &= cur_pass; @@ -1497,15 +1570,15 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "fp8") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 
0 : -2; } return -3; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 251e61bc763d75b67ab1b8b2f7293906702888d3..765c221a7b17630aa3e09786e9430c88df45069c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -16,11 +16,35 @@ #include #include +struct FmhaFwdFp16 +{ +}; + +struct FmhaFwdBf16 +{ +}; + +struct FmhaFwdFp8 +{ +}; + +struct FmhaFwdBf8 +{ +}; + +struct FmhaFwdFp8Fp16 +{ +}; + +struct FmhaFwdFp8Bf16 +{ +}; + template struct FmhaFwdTypeConfig; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; @@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; @@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; @@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t; @@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args void* block_table_ptr; ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. 
const void* cache_batch_idx; @@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args // seqlen_k = kargs.seqlen_k // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // kvcache mode (use same kernel as batch mode): + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; @@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - const void* cache_batch_idx; + const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -278,92 +316,102 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { - return FmhaKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - 
args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + 
args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - return ck_tile::make_tuple(kargs, grids); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } } template @@ -389,6 +437,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, args.scale_s, args.scale_p, args.stride_q, @@ -458,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) } }(); - dim3 grids = - Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); return ck_tile::make_tuple(kargs, grids); } @@ -667,7 
+719,6 @@ std::string fmha_fwd_splitkv_get_name_(); template ; static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN1 = kN1_; static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 996032a7177977867dd00f564aea1ceacba86ed1..faf3f08437a7d910113104e16b90ccf67f7b2598 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode, std::string k_val, std::string k_pad_val, ck_tile::index_t seqlen_k_min = 0, - bool use_kvcache = false, + bool need_append_kvcache = false, std::optional seed = std::nullopt) { #define _S2I_(str_) static_cast(std::atoi((str_).c_str())) @@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode, const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); std::vector seqlen_ks(batch, seqlen_k_max); - if(1 < batch && use_kvcache) + if(1 < batch && need_append_kvcache) { // to keep the original s_k value, we always use seqlen_k_max in first batch randints(std::next(seqlen_ks.begin()), diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index 1bf74bc0553296f498004c024477034dd31797d0..fa69ac0f7ac8b2f3044e813cb22a55e294eec832 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress) 
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 3573d70cd2615be0e1cc65a62613e4de0015f636..817f62dae7a6c876d4879581546d81e4fd144355 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -59,7 +59,7 @@ args: -kname print kernel name or not (default:1) -prec_i input precision (default:fp16) -prec_o output precision, set auto will be the same as input (default:auto) - -prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) + -prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) -prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto) -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) @@ -69,7 +69,7 @@ args: ``` ## limitations -Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose. +Note that `fquant=2`, `fadd=2`, `prec_sm/prec_sy` other than `fp32` are not by default generated. Though our kernel template supports this (TBD: add some flag in generate.py to generate those instances on demand). Besides, `N>8192` case will by default use two-pass pipeline, and `-fquant=1/2` are not supported yet.
If you need to support `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose. ``` # some case diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index ca9e432a4f1ff914abaa851ffc0dacf7d0716f56..700b007fad5990a130e76eb9917feb46d7d3d085 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import argparse @@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True): else: return 'else if' +XBIAS_ENUM_STR_MAP = [ + 'no', + 'xbias'] # pre-norm add bias + FUSED_ADD_ENUM_STR_MAP = [ 'no', 'pras', # pre-norm @@ -35,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [ DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::fp16_t', 'bf16' : 'ck_tile::bf16_t', - 'int8' : 'ck_tile::int8_t'} + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t'} def BOOL_MAP(b_) -> str: if b_: @@ -48,7 +53,7 @@ class layernorm_fwd_codegen: // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct layernorm2d_fwd_traits_ { using XDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; - using XScaleDataType = ck_tile::remove_cvref_t; + using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; @@ -120,14 +127,16 @@ struct layernorm2d_fwd_traits_ static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; static constexpr bool kTwoPass =
kTwoPass_; + static constexpr ck_tile::index_t kXbias = kXbias_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; }; template using traits_ = layernorm2d_fwd_traits_; """ API_COMMON_HEADER = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "layernorm2d_fwd.hpp" @@ -177,26 +190,29 @@ float layernorm2d_fwd_(const S& s, A a) {{ using XDataType = typename Traits_::XDataType; using YDataType = typename Traits_::YDataType; - using XScaleDataType = typename Traits_::XScaleDataType; + using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; using YScaleDataType = typename Traits_::YScaleDataType; - using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kXbias), static_cast(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< - typename LayerNormTypeConfig::XDataType, - typename LayerNormTypeConfig::GammaDataType, - typename LayerNormTypeConfig::BetaDataType, - typename LayerNormTypeConfig::ComputeDataType, - typename LayerNormTypeConfig::YDataType, - typename LayerNormTypeConfig::MeanDataType, - typename LayerNormTypeConfig::InvStdDataType, - typename LayerNormTypeConfig::XScaleDataType, - typename LayerNormTypeConfig::YScaleDataType, + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::XBiasDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename 
LayerNormTypeConfig::SmoothScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, typename Traits_::Shape, PipelineTraits>; @@ -204,12 +220,13 @@ float layernorm2d_fwd_(const S& s, A a) using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; using Pipeline = std::conditional_t; - using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; - using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + static constexpr bool UseRawStore = sizeof(YDataType) == 4; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; @@ -233,7 +250,7 @@ float layernorm2d_fwd_(const S& s, A a) API_BASE = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "layernorm2d_fwd.hpp" @@ -269,12 +286,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, INSTANCE_BASE = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "layernorm2d_fwd_api_common.hpp" // clang-format off -// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep {F_instance_def} // clang-format on @@ -284,6 +301,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, self.working_path = working_path self.kernel_filter = kernel_filter + class k_xbias_enum(IntEnum): + F_NO_XBIAS = 0 + F_ADD_XBIAS = 1 + class k_fuesd_add_enum(IntEnum): F_NO_ADD = 0 F_PRE_ADD = 1 @@ -299,6 +320,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_kPadN : bool F_kSaveMeanInvStd : bool F_kTwoPass : bool + F_kXbias : Any #: layernorm_fwd_codegen.k_xbias_enum F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum @@ -315,6 +337,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, @dataclass class k_problem: F_XDataType : str + F_XBiasDataType : str F_GammaDataType : str F_BetaDataType : str F_ComputeDataType : str @@ -352,7 +375,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, class h_traits: F_XDataType : str F_YDataType : str - F_XScaleDataType : str + F_SmoothScaleDataType : str F_YScaleDataType : str F_Repeat_M : int F_Repeat_N : int @@ -362,15 +385,17 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_kPadN : bool F_kSaveMeanInvStd_ : bool F_kFastFDiv_ : bool + F_kWelford_ : bool F_kTwoPass_ : bool + F_kXbias_ : int F_kFusedAdd : int F_kFusedQuant : int @property def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + t_
= f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ # string when calling this kernel @@ -388,6 +413,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, class h_instance: F_DataTypePair : str F_N : str + F_xbias : int F_add : int F_sweep : int instance_list : List[Any] # List[h_traits] @@ -397,6 +423,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, prec_i, prec_o = self.F_DataTypePair.split(',') dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_xbias != 0: + nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias] if self.F_add != 0: nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: @@ -422,11 +450,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, def name_common_header(self) -> str: return 'layernorm2d_fwd_api_common' - @property - def content_api(self) -> str: + def content_api(self, args) -> str: # 1 sort based on dtype t_dtype_dict = dict() - blobs = self.get_blobs() + blobs = self.get_blobs(args) for blob in blobs: if blob.F_DataTypePair not in t_dtype_dict: t_dtype_dict[blob.F_DataTypePair] = {} @@ -451,19 +478,19 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, if ins.F_kFusedQuant == 0: _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, 
f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) elif ins.F_kFusedQuant == 2: _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) - _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( - f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd, f_sweep_cond = _sweep_cond) inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), F_VEC_COND = _cond, F_instance_func=ins.call_name) #inner_str = inner_str + vec_str - n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' - n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) prec_i, prec_o = dtype_.split(',') d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) @@ -474,77 +501,80 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, def content_common_header(self) -> str: return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) - def get_blobs(self): + def get_blobs(self, args): h_traits = layernorm_fwd_codegen.h_traits h_instance = layernorm_fwd_codegen.h_instance - dynamic_quant_out_dtype = ['int8'] + 
dynamic_quant_out_dtype = ['int8', 'fp8'] # some predefined support range # (prec_i,prec_o) for simplicity this string will be used as key for dict scale_list = [('fp32,fp32')] dtype_list = [('fp16,fp16'), ('bf16,bf16'), - ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + ('fp16,int8'), ('bf16,int8'), + ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out + types_8bit = ('int8', 'fp8') + types_16bit = ('int16', 'fp16', 'bf16') #fused_add_list = [0, 1, 2] #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + xbias_list = [0, 1] fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant - - # rm rn tm tn vn pd mv fdiv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 
'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 
512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} + # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 
0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '4096' :[ 
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list): prec_i, prec_o = dtype.split(',') - scale_x, scale_y = scale_type.split(',') 
+ scale_sm, scale_y = scale_type.split(',') if prec_o in dynamic_quant_out_dtype and fused_quant != 1: continue # skip non dynamic quant case if fused_quant == 1 and hs_key == 'big': @@ -554,20 +584,32 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_ = copy.copy(chs_) # copy the base instance out h_.F_XDataType = prec_i h_.F_YDataType = prec_o - h_.F_XScaleDataType = scale_y - h_.F_YScaleDataType = scale_x + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_kXbias = xbias h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant + # disable welford update for 8bit and 16 bit smallN + if not h_.F_kTwoPass_: + #disable 16 bit when set args disable_16b_welford + if args.disable_16b_welford and prec_i in types_16bit: + h_.F_kWelford_ = False + #disable 8bit by default + elif prec_i in types_8bit or prec_o in types_8bit: + h_.F_kWelford_ = False + #disable 16bit small N + elif prec_i in types_16bit and hs_key == '64': + h_.F_kWelford_ = False current_hs.append(h_) # + "\n" #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs)) return total_blob - def list_blobs(self) -> None: + def list_blobs(self, args) -> None: w_p = Path(self.working_path) list_p = w_p / 'layernorm2d_fwd_blobs.txt' - blobs = self.get_blobs() + blobs = self.get_blobs(args) with list_p.open('w') as list_f: # api related file list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") @@ -576,11 +618,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") - def gen_blobs(self) -> None: + def gen_blobs(self, args) -> None: w_p = Path(self.working_path) - (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + w_str = self.content_api(args) + (w_p / (self.name_api 
+ ".cpp")).write_text(w_str) (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) - blobs = self.get_blobs() + blobs = self.get_blobs(args) for b in blobs: (w_p / (b.name + ".cpp")).write_text(b.content) @@ -588,14 +631,14 @@ def list_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args) def gen_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -663,6 +706,13 @@ if __name__ == "__main__": help="codegen receipt." ) + parser.add_argument( + "--disable_16b_welford", + default=False, + required=False, + help="enable/disable welford for 16bit datatype n > 64" + ) + args = parser.parse_args() # print(f'{args.list_blobs}-{args.gen_blobs}') diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b49c04619d54c6c401128739a236259e04c54dd7..b72485222e6a791d5e270db650be1725fff161e6 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -20,6 +20,14 @@ auto get_elimit() return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1.0; + return ck_tile::make_tuple(rtol, atol); +} + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -35,12 +43,13 @@ auto create_args(int argc, char* argv[]) .insert("kname", "1", "print kernel name or not") .insert("prec_i", "fp16", "input precision") .insert("prec_o", "auto", "output precision, set auto will be the same as input") - .insert("prec_sx", + .insert("prec_sm", "auto", "output quant scale type, set 
auto will use fp32. used when fquant=1") .insert("prec_sy", "auto", "output quant scale type, set auto will use fp32. used when fquant=1 or 2") + .insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") @@ -52,7 +61,7 @@ auto create_args(int argc, char* argv[]) template bool run(const ck_tile::ArgParser& arg_parser) @@ -74,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) float epsilon = arg_parser.get_float("e"); std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sm = arg_parser.get_str("prec_sm"); std::string prec_sy = arg_parser.get_str("prec_sy"); if(prec_o == "auto") { prec_o = prec_i; } - if(prec_sx == "auto") + if(prec_sm == "auto") { - prec_sx = "fp32"; + prec_sm = "fp32"; } if(prec_sy == "auto") { @@ -93,20 +102,25 @@ bool run(const ck_tile::ArgParser& arg_parser) int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); + int xbias = arg_parser.get_int("xbias"); int fused_add = arg_parser.get_int("fadd"); int fused_quant = arg_parser.get_int("fquant"); - if(fused_quant == 1 && prec_o != "int8") + if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8") { - std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; + std::cout + << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases." 
+ << std::endl; return false; } assert(x_stride >= n); - using TypeConfig = LayerNormTypeConfig; + using TypeConfig = + LayerNormTypeConfig; using XDataType = typename TypeConfig::XDataType; using YDataType = typename TypeConfig::YDataType; + using XBiasDataType = typename TypeConfig::XBiasDataType; using GammaDataType = typename TypeConfig::GammaDataType; using BetaDataType = typename TypeConfig::BetaDataType; using XResidualDataType = XDataType; @@ -121,6 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // host verify ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor x_bias_host({n}); ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); @@ -135,30 +150,33 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor y_scale_host_ref({m}); ck_tile::HostTensor y_scale_host_dev({m}); - ck_tile::HostTensor x_scale_host({n}); - ck_tile::HostTensor x_scale_host_dev({n}); + ck_tile::HostTensor sm_scale_host({n}); + ck_tile::HostTensor sm_scale_host_dev({n}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_bias_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); + x_bias_buf.ToDevice(x_bias_host.data()); gamma_buf.ToDevice(gamma_host.data()); beta_buf.ToDevice(beta_host.data()); x_residual_buf.ToDevice(x_residual_host.data()); - x_scale_buf.ToDevice(x_scale_host.data()); + sm_scale_buf.ToDevice(sm_scale_host.data()); auto prec_str = [&]() { auto base_str = prec_i; @@ -179,11 +197,12 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", yr_stride:" << yr_stride << std::flush; layernorm2d_fwd_traits traits{ - prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; + prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant}; layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, - fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? 
sm_scale_buf.GetDeviceBuffer() : nullptr, + x_bias_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(), @@ -210,8 +229,9 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + - sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n + + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + + sizeof(YDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; @@ -221,6 +241,22 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference + if(xbias != 0) + { + // add bias before fadd + int M = x_host.mDesc.get_lengths()[0]; + int N = x_host.mDesc.get_lengths()[1]; + for(int idx_m = 0; idx_m < M; ++idx_m) + { + for(int idx_n = 0; idx_n < N; ++idx_n) + { + x_host(idx_m, idx_n) = ck_tile::type_convert( + ck_tile::type_convert(x_host(idx_m, idx_n)) + + ck_tile::type_convert(x_bias_host(idx_n))); + } + } + } + if(fused_add != 0) { // fused pre_add/pre_add_store @@ -254,8 +290,8 @@ bool run(const ck_tile::ArgParser& arg_parser) for(int n_ = 0; n_ < N_; n_++) { // input smooth outlier - acc_(m_, n_) = - acc_(m_, n_) * ck_tile::type_convert(x_scale_host(n_)); + acc_(m_, n_) = acc_(m_, n_) * + ck_tile::type_convert(sm_scale_host(n_)); } } ComputeDataType absmax = static_cast(0); @@ -265,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser) absmax = a > absmax ? a : absmax; } // printf("cpu:absmax:%f\n", absmax); - ComputeDataType y_scale = absmax / static_cast(127.0); + constexpr ComputeDataType kMaxY = + std::is_same::value ? 240.0 + : std::is_same::value ? 
127.0 + : 0.0; + ComputeDataType y_scale = absmax / kMaxY; y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); for(int n_ = 0; n_ < N_; n_++) { @@ -308,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } - auto [rtol, atol] = get_elimit(); + auto [rtol, atol] = get_elimit(); if(x_stride == n) { @@ -377,16 +417,16 @@ int main(int argc, char* argv[]) std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sm = arg_parser.get_str("prec_sm"); std::string prec_sy = arg_parser.get_str("prec_sy"); if(prec_o == "auto") { prec_o = prec_i; } - if(prec_sx == "auto") + if(prec_sm == "auto") { - prec_sx = "fp32"; + prec_sm = "fp32"; } if(prec_sy == "auto") { @@ -395,37 +435,47 @@ int main(int argc, char* argv[]) int save_mv = arg_parser.get_int("save_mv"); // no dynamic quant case - if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) + if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 
0 : -2; } // dynamic quant case, only in inference - else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } return -3; } diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index a0f2db0e8a478da5a4302fe7439aa1354d3b923a..0538953a580e76322920c3bd1cdb3db78971b591 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,35 +8,40 @@ #include "ck_tile/ops/layernorm2d.hpp" #include -template +template struct LayerNormTypeConfig; -template -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { - using XDataType = ck_tile::half_t; - using YDataType = OutType; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; - using ComputeDataType = float; - using XScaleDataType = XScaleDataType_; - using YScaleDataType = YScaleDataType_; + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { - using XDataType = ck_tile::bf16_t; - using YDataType = OutType; - using GammaDataType = ck_tile::bf16_t; - using BetaDataType = ck_tile::bf16_t; - using MeanDataType = ck_tile::bf16_t; - using InvStdDataType = ck_tile::bf16_t; - using ComputeDataType = float; - using XScaleDataType = XScaleDataType_; - using YScaleDataType = YScaleDataType_; + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using BetaDataType = ck_tile::bf16_t; + using MeanDataType = ck_tile::bf16_t; + using InvStdDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -50,13 +55,14 @@ struct layernorm2d_fwd_traits std::string prec_i; // input precision std::string prec_o; // output precision - // if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise 
can set + // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise // can set arbitrary(will skip check) - std::string prec_sx; // x-scale, used for [1*N] input smooth quant + std::string prec_sm; // x-scale, used for [1*N] input smooth quant std::string prec_sy; // y-scale, used for [M*1] output for next layer bool save_mean_var; // + int xbias; // 0:no-bias, 1:add bias int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index b7fd354bb8e05647dc66ee9f4699757bc8378d8b..ceaf262bd9eb79308d697f0493cdc8536f4f227a 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -1,7 +1,7 @@ #!/bin/sh EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8"; do +for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do for pr_i in "fp16" "bf16" ; do for fadd in "0" "1"; do $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 @@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 done done diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 8ae46cadc65fa56ae93dbd8668c453df9a7451f3..bc3799f015267f1784f1ea80d1e20473d754f8c5 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ 
b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) +add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index e9ffe72a9152d22b62680369ac03ab569c0eecce..4c16f13cefdcfa474c534ce16fc8e50d742999eb 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation -make tile_example_gemm_mem_pipeline -j +make tile_example_gemm_universal -j ``` -This will result in an executable `build/bin/tile_example_gemm_basic` +This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` ## example ``` @@ -22,6 +22,9 @@ args: -m m dimension (default:1024) -n n dimension (default:2048) -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index b7d869344238d1ecf3a2518683d961e2d5292c96..2e04780eb06fda3c174bc2cf4edb1cbe0782fd43 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -9,29 +9,29 @@ #include #include -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +template +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; - constexpr int kBlockPerCu = 1; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -39,59 +39,47 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. 
- constexpr bool CShuffleEpilogue = - std::is_same_v; + constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" @@ -108,4 +96,46 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) #include "run_gemm_example.inc" +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 23e99bc2a88c7584b6ce9be2e59357cb801bf068..5fa94f5f7284437fbd46f8bd76245f7a73ca4625 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -8,6 +8,27 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + 
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif template struct GemmBasicTypeConfig; @@ -22,6 +43,33 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -43,37 +91,32 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; -struct gemm_basic_args +template <> +struct DataTypeTraits { - const void* p_a; - const void* p_b; - void* p_c; - ck_tile::index_t kbatch; - ck_tile::index_t 
M; - ck_tile::index_t N; - ck_tile::index_t K; - ck_tile::index_t stride_A; - ck_tile::index_t stride_B; - ck_tile::index_t stride_C; + static constexpr const char* name = "bf8"; }; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "3840", "m dimension") + arg_parser.insert("m", "3840", "m dimension") .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") @@ -82,11 +125,12 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } // host API -float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp deleted file mode 100644 index ff9d8bad3275bfb1f8219414a8c522fe7b3cdeca..0000000000000000000000000000000000000000 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ /dev/null @@ -1,188 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include - -#include -#include -#include -#include -#include - -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host.hpp" -#include "gemm_basic.hpp" - -template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) -{ - // ToDo: This will be modified by the codegen code later. - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = true; - constexpr bool kPadN = true; - constexpr bool kPadK = true; - - constexpr int kBlockPerCu = 1; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - - using Traits = ck_tile::TileGemmTraits; - - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< - ck_tile::GemmPipelineProblem>; - - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< - ck_tile::UniversalGemmPipelineProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = 
Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); - constexpr dim3 blocks = Kernel::BlockSize(); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - if(has_hot_loop) - { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - } - else - 
{ - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - - return ave_time; -} - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 8db131738b72366664eefba58e0a7092b6c44f30..028f8a44c3e249e565470668fc82530203d2cf15 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -1,8 +1,37 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once -template +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -16,11 +45,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - gemm_basic_args args; - args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); - args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); - args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); - args.kbatch = kbatch; + ck_tile::GemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; @@ -28,26 +57,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( + float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::string op_name{"Gemm{MemBoundPipeline}"}; - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = sizeof(ADataType) * M 
* K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name + << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -58,6 +92,11 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -66,55 +105,22 @@ int run_gemm_example_with_layouts(int argc, ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - ck_tile::index_t batch_size = arg_parser.get_int("b"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - - using namespace ck_tile::literals; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, 
col}, {1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, a_layout); - stride_B = f_get_default_stride(K, N, stride_B, b_layout); - stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); - - ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout)); - ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout)); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); // TODO: add different init types - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); @@ -127,7 +133,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, + invoke_gemm(a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_dev_buf, M, @@ -136,7 +143,7 @@ int run_gemm_example_with_layouts(int argc, stride_A, stride_B, stride_C, - batch_size, + kbatch, n_warmup, n_repeat); @@ -146,72 +153,84 @@ int 
run_gemm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_host_ref); - - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); - + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { ck_tile::HostTensor c_m_n_gpu_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_gemm_gpu( - a_m_k_dev_buf, b_k_n_dev_buf, 
c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The GPU veification result is: " << (pass ? 
"correct" : "fail") << std::endl; } return pass; } - -int run_gemm_example(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - - if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); - } -} diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh new file mode 100755 index 0000000000000000000000000000000000000000..a1646da5bd0d3df6fcc06dad58a5a00cfc43c23b --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh new file mode 100644 index 0000000000000000000000000000000000000000..21462616be3681bc696baef0d4554dbf77b64dac --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh new file mode 100755 index 0000000000000000000000000000000000000000..c4cf4ddcbfba8815f19456f857827fdfdaba1ce6 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh new file mode 100644 index 0000000000000000000000000000000000000000..903b4a3c0ff385408cbb47aff660f5c57c2b802f --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh new file mode 100644 index 0000000000000000000000000000000000000000..8c92c2e99116047506b5f417433afc1ca11e2d0a --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh new file mode 100644 index 0000000000000000000000000000000000000000..e238006c7d0e0cf15c63dc0cfb456415d77f24d1 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh index 2e2e7fdf90524bfdaedb240b4bc13c6556be9225..45bd1bed614f9e4c0f4583ff3f91bc129f5dd545 100755 --- a/example/ck_tile/03_gemm/script/run_full_test.sh +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -19,7 +19,27 @@ echo 'Host name: ' $host_name export GPU_arch=$4 echo 'GPU_arch: ' $GPU_arch +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + # get GPU architecture and compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + # run verification tests 
-example/ck_tile/03_gemm/script/smoke_test.sh +example/ck_tile/03_gemm/script/smoke_test_basic.sh +example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh + +# run performance benchmarks +export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log" +print_log_header $gemm_basic_log $env_type $branch $host_name +example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log -# We do not have a performance benchmark for gemm yet. Will add it in the future. \ No newline at end of file +export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log" +print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name +example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log diff --git a/example/ck_tile/03_gemm/script/smoke_test.sh b/example/ck_tile/03_gemm/script/smoke_test.sh deleted file mode 100755 index 4d9a64bf40dc38f19dd7612ce1dfb9cd14e5ab45..0000000000000000000000000000000000000000 --- a/example/ck_tile/03_gemm/script/smoke_test.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" -KNAME=1 - -export CK_WARMUP=0 -export CK_REPEAT=1 - -COMMON_ARGS='-v=2 -warmup=0 -repeat=1' - -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do - - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." 
- # Optionally, exit or break if you need to halt further execution - # exit 1 - fi - - done - done - done - done -} - -set -x - -run_fp16_tests - -set +x \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test_basic.sh b/example/ck_tile/03_gemm/script/smoke_test_basic.sh new file mode 100755 index 0000000000000000000000000000000000000000..7ca6759f420bc0f0af3868ac2f25644f2e6fd689 --- /dev/null +++ b/example/ck_tile/03_gemm/script/smoke_test_basic.sh @@ -0,0 +1,36 @@ +#!/bin/bash +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 64 128; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" + +set +x diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh new file mode 100755 index 0000000000000000000000000000000000000000..951f8aa63ae5bf9c6aeb97106baa8af644eb8c63 --- /dev/null +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -0,0 +1,36 @@ +#!/bin/bash +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 512 1024; do + for n in 512 2048; do + for k in 512 1024; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? 
-eq 0 ]; then + echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" + +set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08a9cdb24b55f41695280350b565cc9d04747638 --- /dev/null +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_basic.hpp" + +template +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr 
ck_tile::index_t K_Warp_Tile = 16; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + // =============================================== + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile:: + TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = + GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if 
constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +#endif + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} + +#include "run_gemm_example.inc" + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + return 
run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "C" && b_layout == "C") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "C" && b_layout == "R") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index 005541dc62bc428df3011df44234754ee663f9b2..602661f7791a69192a8df3b9a63f5eb48eca50da 100644 --- 
a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // using WarpTile = ck_tile::sequence<1, 512>; // using Vector = ck_tile::sequence<1, 8>; - constexpr ck_tile::index_t kBlockSize = 512; + constexpr ck_tile::index_t kBlockSize = 256; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); std::cout << "grid size " << kGridSize << std::endl; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp index 93c662a288fa804d72ad703152d8d83920ec920a..e5ded0ef3b6938ca748ccc065582943b47ea9c8b 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp @@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, else if(t.permute.compare("0,1,3,4,2,5") == 0) { constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + matrix_core_permute_style::b_nr_kr_kw_nw_kv; using Kernel = matrix_core_swizzle_kernel; @@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, else if(t.permute.compare("0,1,3,4,2,5") == 0) { constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + matrix_core_permute_style::b_nr_kr_kw_nw_kv; using Kernel = matrix_core_swizzle_kernel; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 60ac103ec3d1f23a535813bdebd0048f20e854d1..28f4c452bcae42b8679d5d1b3b4777d5053ff205 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -42,8 +42,8 @@ enum class matrix_core_permute_style 
{ permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 - permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 - permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, }; // assume this is B matrix, originally we have batch*n*k @@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel else { // clang-format off - // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten + // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten constexpr index_t Kv = Alignment; constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; @@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel make_tuple(sequence<0>{}, sequence<1>{})); return tmp_1; #else - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, constexpr index_t kv = Alignment; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; @@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel else { #if MERGE_2D_013425 - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv return make_tile_window(dst_view, make_tuple(number{}, number{}), {i_n * NPerBlock, i_k * KPerBlock}, get_dst_dist()); #else - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv constexpr index_t kv = Alignment; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index af95b64e69e23f55cbed079f0733743d4b160993..477ae370b9dbb6ab2dae9b58cce5275463cb291c 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -264,7 +264,7 @@ bool run(const 
ck_tile::ArgParser& arg_parser) { if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) { - // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 matrix_core_swizzle_traits t; t.data_type = data_type; t.permute = arg_parser.get_str("perm"); diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt index a3ff8fdf4595715a1620e3a8f46de870ecd06bb1..5684c9b2e00f30248e39e2a0f93d08a9cb2e33bd 100644 --- a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -1,16 +1,39 @@ +set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd") +set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") +if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all") + set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS}) +endif() + +# generate a list of kernels, but do not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. 
${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs +) + set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" + message("adding ${TILE_RMSNORM2D_FWD}") -file(GLOB INSTANCE_SRCS instances/*.cpp) add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) +target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress) target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 34df7b74fa3c710becc10670b54c76f910317781..48c150009e291b6baab28066d2185fc62487bc94 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -1,6 +1,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/rmsnorm2d.hpp" #include @@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser) assert(stride >= n); - using 
XDataType = DataType; - using YDataType = DataType; - using GammaDataType = DataType; - using InvRmsDataType = ck_tile::null_type; + using XDataType = DataType; + using YDataType = DataType; + using GammaDataType = DataType; + using InvRmsDataType = ck_tile::null_type; + using SmoothScaleDataType = ck_tile::null_type; + using YScaleDataType = ck_tile::null_type; using ComputeDataType = float; @@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser) using BlockTile = ck_tile::sequence<2, 128>; using WarpTile = ck_tile::sequence<1, 64>; using Vector = ck_tile::sequence<1, 1>; + using Shape = ck_tile::Generic2dBlockShape; + + using PipelineTraits = + ck_tile::Rmsnorm2dFwdTraits; // fuse quant - using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; + PipelineTraits>; using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; using Pipeline = std::conditional_t; - using Kernel = ck_tile::Rmsnorm2dFwd; + + using Default2DEpilogueProblem = ck_tile:: + Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + using Kernel = ck_tile::Rmsnorm2dFwd; ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), + nullptr, + nullptr, gamma_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(), nullptr, + nullptr, + nullptr, epsilon, m, n, + stride, + stride, + stride, stride}; auto kargs = Kernel::MakeKargs(args); diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..dadb2268b2e50639a603f4f17915e6508eca1186 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -0,0 +1,683 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + + +def get_if_str(idx, total, lase_else = True): + if idx == 0: + return 'if' + elif idx < total - 1: + return 'else if' + else: + if lase_else: + return 'else' + else: + return 'else if' + +FUSED_ADD_ENUM_STR_MAP = [ + 'no', + 'pras', # pre-norm + 'pra' ] # post-norm + +FUSED_FUSED_SWEEP_STR_MAP = [ + 'no', + 'sdquant', # smooth dynamic quant + 'dquant' ] # dynamic quant (without sm_scale) + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::fp16_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t'} + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + + +class rmsnorm_fwd_codegen: + API_TRAITS_DEFINE = """ +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct rmsnorm2d_fwd_traits_ +{ + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using SmoothScaleDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if 
constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; +}; + +template +using traits_ = rmsnorm2d_fwd_traits_; +""" + + API_COMMON_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "rmsnorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = rmsnorm2d_fwd_args; + +{F_traits_define} + +template +float rmsnorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename RmsnormTypeConfig::ComputeDataType; + + using PipelineTraits = + ck_tile::Rmsnorm2dFwdTraits(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + + using PipelineProblem = + ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, + typename RmsnormTypeConfig::GammaDataType, + typename RmsnormTypeConfig::ComputeDataType, + typename RmsnormTypeConfig::YDataType, + typename RmsnormTypeConfig::InvRmsDataType, + typename RmsnormTypeConfig::SmoothScaleDataType, + typename RmsnormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Rmsnorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + 
+ API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "rmsnorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, + rmsnorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_api_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +{F_instance_def} +// clang-format on + +""" + + API_PER_DTYPE = """ + {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE = """ + {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE = """ + {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kFusedAdd : Any + F_kFusedQuant : Any + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_GammaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_InvRmsDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + 
class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_SmoothScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveInvRms : bool + F_kTwoPass : bool + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'rmsnorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float rmsnorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + 
'_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + instance_defs += ins.def_name + '\n' + return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'rmsnorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'rmsnorm2d_fwd_api_common' + + @property + def content_api(self) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs() + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) 
&& (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self): + h_traits = rmsnorm_fwd_codegen.h_traits + h_instance = rmsnorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8', 'fp8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8'), + ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + fused_add_list = [0, 1] + fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + + # rm rn tm tn vn pd mv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, 
False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, 
False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N + for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_sm, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: + continue # skip non dynamic quant 
case + if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'rmsnorm2d_fwd_blobs.txt' + blobs = self.get_blobs() + with list_p.open('w') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self) -> None: + w_p = Path(self.working_path) + (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs() + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs() + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK rmsnorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." 
+ ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script has 2 modes + # 1) list_blobs mode: generates a txt file listing all the files that are going to be generated. + # this is useful in a build system like cmake to construct source code dependencies, by + # reading the content out of this file + # 2) gen_blobs mode: generates the actual kernel instances and api. If used from a framework + # like FA, only this mode is needed + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir (presumably --working_path) and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", # NOTE(review): parsed but not read anywhere in this script + default="all", + required=False, + help="enable/disable some feature. default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", # NOTE(review): parsed but not read anywhere in this script + default=0, + required=False, + help="codegen receipt." 
+ ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) \ No newline at end of file diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp deleted file mode 100644 index b8697183f96bc6b4421ecbd8b26353ae7c00941e..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include "rmsnorm2d_fwd.hpp" - -template -using trait_ = rmsnorm2d_fwd_traits_; - -template -float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/, - rmsnorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ - float r = -1; - // clang-format off - // rm rn tm tn vn pd rms 2p - if(a.n <= 64) { - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 128) { - if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 256) { - if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 512) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 768) { - if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 1024) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = 
rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 1536) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 2048) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 3072) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 4096) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n > 4096) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - return r; - // clang-format on -} - -float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) -{ - - if(t.data_type.compare("fp16") == 0) - { - return rmsnorm2d_fwd_b16_(t, a, s); - } - else if(t.data_type.compare("bf16") == 0) - { - return rmsnorm2d_fwd_b16_(t, a, s); - } - else - throw std::runtime_error("Without supported instances!"); -} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp deleted file mode 100644 index 5e2a35f9e8fb496f21832c222b9f68042a63a21d..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp +++ /dev/null @@ 
-1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -#if 0 -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -template float rmsnorm2d_fwd_>(const S&, A); -#endif - -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp deleted file mode 100644 index 8c734806e18b4782092f7a0e5cc460b3abc158d4..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp deleted file mode 100644 index 9222001433464eebcf1e20911b6b06b85c117270..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp deleted file mode 100644 index ed33c849232cc95d251240f7d146678273bd4e52..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp deleted file mode 100644 index b753bbc3458d3194f0cc6962b51d499bd331848b..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp deleted file mode 100644 index 27cb9bdf3d47dc34909e0c1333c5daa32e640274..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index 23afb5672b4b109fa9d2b89abec5318766540e92..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp deleted file mode 100644 index b428f58051bae64bd0497a991a7fc265ad96ec3d..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp deleted file mode 100644 index 3001106697dafa0d521672af936e17b1cf2fddac..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp deleted file mode 100644 index e9c8d6a1d444b0b085bbfc20575352a1d766b2a3..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp deleted file mode 100644 index 15198eebe67258266529ed81a7c8f5bf16d48ca2..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -#if 0 -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -template float rmsnorm2d_fwd_>(const S&, A); -#endif - -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp deleted file mode 100644 index 8ac85fa9b5a68246a5af7c039b4131a3b35c9c56..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp deleted file mode 100644 index 10e8fafc2f4c780ff622ee501f685671fc7dd25a..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp deleted file mode 100644 index 4e1a80bf64b3598864765454fd46f8ce9c9c6eb0..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp deleted file mode 100644 index 45e56a92b8886ffe3b07189646aad12caeffb359..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp deleted file mode 100644 index 35401f6f82b50c40137599456250c13c46092cb2..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index 1e3700fad3ab61ff669d2950f96578170a01320e..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp deleted file mode 100644 index cdc4d00bd2336a1e55c900cd078dd8cde52ac11b..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp deleted file mode 100644 index ec80c2ee4a93f999be3960ba7154a16d8992f302..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp deleted file mode 100644 index ddfc5a54e8e6ea3129804f28a56aed98b5432f67..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp deleted file mode 100644 index 8f6ff84b643d2b7fafebc5b0a9ef6ade1ebdbd23..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp +++ /dev/null @@ -1,65 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include "rmsnorm2d_fwd.hpp" -#include - -#pragma once - -using S = ck_tile::stream_config; -using A = rmsnorm2d_fwd_args; - -template -using trait_ = rmsnorm2d_fwd_traits_; - -template -float rmsnorm2d_fwd_(const S& s, A a) -{ - using DataType = typename Traits_::DataType; - - using PipelineProblem = - ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, - typename RmsnormTypeConfig::GammaDataType, - typename RmsnormTypeConfig::ComputeDataType, - typename RmsnormTypeConfig::YDataType, - typename RmsnormTypeConfig::InvRmsDataType, - typename Traits_::Shape, - Traits_::kPadN, - Traits_::kSaveInvRms, - Traits_::kTwoPass>; - - using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; - using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; - - using Kernel = ck_tile::Rmsnorm2dFwd; - - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; - - auto kargs = Kernel::MakeKargs(a); - if(s.log_level_ > 0) - std::cout << ", " << Kernel::GetName() << std::flush; - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 698a8b43eb9329f5bfb0c61b78cea98a0cf07f5f..cdee6dfb80041ef3afa8e5545313b1bab472399b 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -19,17 +19,37 @@ auto get_elimit() return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit() +{ + double rtol = 1e-02; + double atol = 1.0; + return ck_tile::make_tuple(rtol, atol); +} + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "x row_stride, if -1 then 
equal to n") + .insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") - .insert("prec", "fp16", "precision") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sm", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1 or 2") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -37,28 +57,70 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - - assert(stride >= n); - - using TypeConfig = RmsnormTypeConfig; - - using XDataType = typename TypeConfig::XDataType; - using YDataType = typename TypeConfig::YDataType; - using GammaDataType = typename TypeConfig::GammaDataType; + ck_tile::index_t m = arg_parser.get_int("m"); + 
ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; + assert(x_stride >= n); + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8") + { + std::cout + << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases." 
+ << std::endl; + return false; + } + + using TypeConfig = + RmsnormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; using InvRmsDataType = std::conditional_t; @@ -66,43 +128,84 @@ bool run(const ck_tile::ArgParser& arg_parser) using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor gamma_host({n}); + ck_tile::HostTensor sm_scale_host({n}); + ck_tile::HostTensor sm_scale_host_dev({n}); - ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); + + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); ck_tile::HostTensor invRms_host_ref({m}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + sm_scale_buf.ToDevice(sm_scale_host.data()); + + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + std::cout << "[" << prec_str << "]" + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; - rmsnorm2d_fwd_traits traits{data_type, SaveRms}; + rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant}; rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr, gamma_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(), - nullptr, + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_invRms, unsupported yet epsilon, m, n, - stride}; + x_stride, // x row_stride + xr_stride, // x residule row stride + y_stride, // y row stride + yr_stride}; // y residule row stride float ave_time = rmsnorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; + num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0; + num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0; + num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0; + num_byte += (fused_quant == 1) ? 
sizeof(SmoothScaleDataType) * n : 0; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; @@ -112,38 +215,135 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference - ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + [](auto x_, auto r_) { + auto o_ = ck_tile::type_convert(x_) + + ck_tile::type_convert(r_); + return ck_tile::type_convert(o_); + }); + } + + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = acc_(m_, n_) * + ck_tile::type_convert(sm_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + constexpr ComputeDataType kMaxY = + std::is_same::value ? 240.0 + : std::is_same::value ? 
127.0 + : 0.0; + ComputeDataType y_scale = absmax / kMaxY; + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor); + } + else + { + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + } y_buf.FromDevice(y_host_dev.data()); - auto [rtol, atol] = get_elimit(); - if(stride == n) + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(y_residual_host_dev.data()); + } + + auto [rtol, atol] = get_elimit(); + if(x_stride == n) { pass = ck_tile::check_err( - y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol); + + if(fused_add == 1) + { + pass &= ck_tile::check_err(y_residual_host_dev, + x_host, + std::string("\nADD Error: Incorrect results!"), + rtol, + atol); + } } else { for(int i_r = 0; i_r < m; i_r++) { - std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, - y_host_dev.begin() + i_r * stride + n); - std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, - y_host_ref.begin() + i_r * stride + n); + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); pass &= ck_tile::check_err(y_host_dev_row, y_host_ref_row, - std::string("OUT[") + std::to_string(i_r) + + std::string("\nOUT[") + std::to_string(i_r) + std::string("] Error: Incorrect results!"), rtol, atol); + + if(fused_add == 1) + { + std::vector y_residual_host_dev_row( + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); + std::vector 
y_residual_host_ref_row( + x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n); + pass &= ck_tile::check_err(y_residual_host_dev_row, + y_residual_host_ref_row, + std::string("\nADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("\nSCALE Error: Incorrect results!"), + rtol, + atol); + } + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } @@ -156,23 +356,65 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_rms = arg_parser.get_int("save_rms"); - if(data_type == "fp16" && save_rms) + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + int save_rms = arg_parser.get_int("save_rms"); + + if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) + { + return run(arg_parser) ? 
0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "fp16" && !save_rms) + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && save_rms) + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && !save_rms) + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index b4d429d46f4a1a418527bf515ccd5e06e6243352..566b94442d4c934cc63375e18a5a93eb6b55841c 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,27 +8,34 @@ #include "ck_tile/ops/rmsnorm2d.hpp" #include -template +template struct RmsnormTypeConfig; -template <> -struct RmsnormTypeConfig +template +struct RmsnormTypeConfig { - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using InvRmsDataType = ck_tile::half_t; - using ComputeDataType = float; + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using GammaDataType = ck_tile::half_t; + using InvRmsDataType = ck_tile::half_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template <> -struct RmsnormTypeConfig +template +struct RmsnormTypeConfig { - using XDataType = ck_tile::bf16_t; - using YDataType = ck_tile::bf16_t; - using GammaDataType = ck_tile::bf16_t; - using InvRmsDataType = ck_tile::bf16_t; - using ComputeDataType = float; + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using GammaDataType = ck_tile::bf16_t; + using InvRmsDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs { }; -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct rmsnorm2d_fwd_traits_ -{ - using DataType = ck_tile::remove_cvref_t; - - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); - } - else - { - 
// static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; - } - }(); - - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; - static constexpr ck_tile::index_t Repeat_N = Repeat_N_; - - static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; - static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; - - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kTwoPass = kTwoPass_; -}; - template float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); // This is the public API, will be generated by script struct rmsnorm2d_fwd_traits { - std::string data_type; + std::string prec_i; // input precision + std::string prec_o; // output precision + + // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set + // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise + // can set arbitrary(will skip check) + std::string prec_sm; // x-scale, used for [1*N] input smooth quant + std::string prec_sy; // y-scale, used for [M*1] output for next layer + bool save_rms; + int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add + int fused_quant; // 0:no-sweep, 
1:smooth-dynamic-quant, 2:dynamic-quant }; float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 758d6de54680cc303068044c5e0fc8d27baba4b2..ab890738b31aff69169d34d6af3214d7f2b63dbc 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,30 +1,34 @@ #!/bin/sh EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" +for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do for pr_i in "fp16" "bf16" ; do -$EXE -prec=$pr_i -m=99 -n=13 -$EXE -prec=$pr_i -m=17 -n=16 -$EXE -prec=$pr_i -m=1 -n=100 -$EXE -prec=$pr_i -m=4 -n=128 -$EXE -prec=$pr_i -m=80 -n=127 -$EXE -prec=$pr_i -m=22 -n=255 -stride=256 -$EXE -prec=$pr_i -m=7 -n=599 -$EXE -prec=$pr_i -m=19 -n=512 -$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 -$EXE -prec=$pr_i -m=11 -n=510 -$EXE -prec=$pr_i -m=171 -n=676 -stride=818 -$EXE -prec=$pr_i -m=91 -n=636 -$EXE -prec=$pr_i -m=12 -n=768 -stride=800 -$EXE -prec=$pr_i -m=100 -n=766 -stride=812 -$EXE -prec=$pr_i -m=31 -n=1024 -$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 -$EXE -prec=$pr_i -m=8 -n=1501 -$EXE -prec=$pr_i -m=3 -n=1826 -$EXE -prec=$pr_i -m=5 -n=2040 -$EXE -prec=$pr_i -m=7 -n=2734 -$EXE -prec=$pr_i -m=1 -n=3182 -$EXE -prec=$pr_i -m=9 -n=4096 -$EXE -prec=$pr_i -m=3 -n=8192 -$EXE -prec=$pr_i -m=1 -n=10547 -$EXE -prec=$pr_i -m=3 -n=17134 +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 
-n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done done diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt index 09a56c6dabf6fb158ec5108e210a6b99e5b104db..3849833aca2aebd3cbcdc8a8c672961ba01dfc5a 100644 --- a/example/ck_tile/12_smoothquant/CMakeLists.txt +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) -file(GLOB INSTANCE_SRCS instances/*.cpp) -add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index 
3a26eb6a77904f47a810f53271b0b3657fc474ee..20e1591516f4b614f023d2c53ac6f5fd8c3d5110 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "input stride per row, if -1 then equal to n") + .insert("y_stride", "-1", "output stride per row, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") @@ -49,44 +50,47 @@ auto create_args(int argc, char* argv[]) template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; std::string data_type = arg_parser.get_str("prec"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - assert(stride >= n); + assert(x_stride >= n); - using XDataType = DataType; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = DataType; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); - ck_tile::HostTensor xscale_host({n}); + 
ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor smscale_host({n}); ck_tile::HostTensor yscale_host_ref({m}, {1}); ck_tile::HostTensor yscale_host_dev({m}, {1}); - ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - xscale_buf.ToDevice(xscale_host.data()); + smscale_buf.ToDevice(smscale_host.data()); constexpr bool kTwoPass = true; @@ -97,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::SmoothquantPipelineProblem; ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), - xscale_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), m, n, - stride}; + x_stride, + y_stride}; auto kargs = Kernel::MakeKargs(args); @@ -133,20 +138,20 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { using YDataType = ComputeDataType; - ck_tile::HostTensor y_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_host({m, n}, {y_stride, 1}); // smooth outlier { auto f = [&](auto n_) { - auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + auto v_smscale = ck_tile::type_convert(smscale_host(n_)); for(int m_ = 0; m_ < m; ++m_) { auto v_x = 
ck_tile::type_convert(x_host(m_, n_)); - y_host(m_, n_) = v_x * v_xscale; + y_host(m_, n_) = v_x * v_smscale; } }; - ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())( std::thread::hardware_concurrency()); } @@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser) qy_buf.FromDevice(qy_host_dev.data()); auto [rtol, atol] = get_elimit(); - if(stride == n) + if(y_stride == n) { pass = ck_tile::check_err(qy_host_dev, qy_host_ref, @@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser) { for(int i_r = 0; i_r < m; i_r++) { - std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, - qy_host_dev.begin() + i_r * stride + n); - std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, - qy_host_ref.begin() + i_r * stride + n); + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride, + qy_host_dev.begin() + i_r * y_stride + + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride, + qy_host_ref.begin() + i_r * y_stride + + n); pass &= ck_tile::check_err(qy_host_dev_row, qy_host_ref_row, std::string("qy[") + std::to_string(i_r) + @@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride - << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", y_stride:" << y_stride << ", valid:" << (pass ? 
"y" : "n") << std::flush + << std::endl; } return pass; diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp index cdf93f6fcfd3d523c7723e28e006d089bf49f350..555159566eed1fb2e5b341bb9941aeba8cc9a1a4 100644 --- a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "smoothquant.hpp" @@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a) using PipelineProblem = ck_tile::SmoothquantPipelineProblem< typename SmoothquantTypeConfig::XDataType, - typename SmoothquantTypeConfig::XScaleDataType, + typename SmoothquantTypeConfig::SmoothScaleDataType, typename SmoothquantTypeConfig::ComputeDataType, typename SmoothquantTypeConfig::YScaleDataType, typename SmoothquantTypeConfig::QYDataType, diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index ed01d654fda239d87f34a6131d4c4d9744f9c859..f3ba587132fe4d8b7cb3d08f154c237b2ed53206 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "input stride per row, if -1 then equal to n") + .insert("y_stride", "-1", "output stride per row, if -1 then equal to n") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") @@ -47,65 +48,70 @@ auto create_args(int argc, 
char* argv[]) template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; std::string data_type = arg_parser.get_str("prec"); int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - assert(stride >= n); + assert(x_stride >= n); using TypeConfig = SmoothquantTypeConfig; - using XDataType = typename TypeConfig::XDataType; - using XScaleDataType = typename TypeConfig::XScaleDataType; - using YScaleDataType = typename TypeConfig::YScaleDataType; - using QYDataType = typename TypeConfig::QYDataType; - using ComputeDataType = typename TypeConfig::ComputeDataType; + using XDataType = typename TypeConfig::XDataType; + using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); - ck_tile::HostTensor xscale_host({n}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor smscale_host({n}); ck_tile::HostTensor yscale_host_ref({m}, {1}); ck_tile::HostTensor yscale_host_dev({m}, {1}); - ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); 
ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - xscale_buf.ToDevice(xscale_host.data()); + smscale_buf.ToDevice(smscale_host.data()); std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride + << std::flush; smoothquant_traits traits{data_type}; smoothquant_args args{x_buf.GetDeviceBuffer(), - xscale_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), m, n, - stride}; + x_stride, + y_stride}; float ave_time = smoothquant( traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); - std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n + sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -116,20 +122,20 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { using YDataType = ComputeDataType; - ck_tile::HostTensor y_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_host({m, n}, {y_stride, 1}); // smooth outlier { auto f = [&](auto n_) { - auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + auto v_smscale = ck_tile::type_convert(smscale_host(n_)); for(int m_ = 0; m_ < m; ++m_) { auto v_x = ck_tile::type_convert(x_host(m_, n_)); - y_host(m_, n_) = v_x * v_xscale; + y_host(m_, n_) = v_x * v_smscale; } }; - ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())( std::thread::hardware_concurrency()); } @@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser) qy_buf.FromDevice(qy_host_dev.data()); auto [rtol, atol] = get_elimit(); - if(stride == n) + if(y_stride == n) { pass = ck_tile::check_err(qy_host_dev, qy_host_ref, @@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser) { for(int i_r = 0; i_r < m; i_r++) { - std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, - qy_host_dev.begin() + i_r * stride + n); - std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, - qy_host_ref.begin() + i_r * stride + n); + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride, + qy_host_dev.begin() + i_r * y_stride + + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride, + qy_host_ref.begin() + i_r * y_stride + + n); pass &= ck_tile::check_err(qy_host_dev_row, qy_host_ref_row, std::string("qy[") + std::to_string(i_r) + diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp 
b/example/ck_tile/12_smoothquant/smoothquant.hpp index 26a598db55bc19c5ce9e1035eeef2add79fd1f35..83ad7b012ca4fa40c75bd2ef9786eeff5a8bdf6d 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,21 +14,21 @@ struct SmoothquantTypeConfig; template <> struct SmoothquantTypeConfig { - using XDataType = ck_tile::half_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = ck_tile::half_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; }; template <> struct SmoothquantTypeConfig { - using XDataType = ck_tile::bf16_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = ck_tile::bf16_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; }; // runtime args diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 25e99c5306fddc9786d06d1902ebcdf0171413f9..723fb3f69f1e70877d053b59cdb3ea25864089c3 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,18 +3,42 @@ #include "moe_sorting_api.hpp" -#define MOE_SORTING_DISPATCH(unroll_num_) \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - using ms_problem = ck_tile::MoeSortingProblem; \ - using kernel = ck_tile::MoeSortingKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 
blocks = kernel::BlockSize(a); \ - const auto lds_bytes = kernel::GetSmemSize(a); \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ +#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr ck_tile::index_t expert_tile = expert_tile_; \ + using ms_problem = \ + ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#define MOE_SORTING_DISPATCH(unroll_num_) \ + if(a.num_experts <= 8) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \ + } \ + else if(a.num_experts <= 16) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \ + } \ + else if(a.num_experts <= 32) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \ + } \ + else if(a.num_experts <= 64) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ + } + float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi case(6): { MOE_SORTING_DISPATCH(6); } - case(7): { - MOE_SORTING_DISPATCH(7); - } case(8): { MOE_SORTING_DISPATCH(8); } - case(9): { - MOE_SORTING_DISPATCH(9); - } case(10): { MOE_SORTING_DISPATCH(10); } - case(11): { - MOE_SORTING_DISPATCH(11); - } default: { MOE_SORTING_DISPATCH(4); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 
91b54932cec5bcb94a90374259619c9bfedbb7f0..0cb393f7dedd450d03fdadcaac1c929791ff4ef8 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -5,7 +5,7 @@ #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -#include "ck_tile/ops/moe_sorting.hpp" +#include "ck_tile/ops/fused_moe.hpp" struct moe_sorting_trait { diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 1fc5eafcb005e8ce68832f94b28140fc6d848d8b..3ff8a7332daa45d8882430a50e20c7ed86a9454a 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19 $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 -$EXE -t=333 -e=99 -k=13 \ No newline at end of file +$EXE -t=333 -e=99 -k=13 +$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 diff --git a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..12224a39a2c83607b6bbae5c700a1e14571871fb --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt @@ -0,0 +1,25 @@ +function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC) + message("adding ${TARGET_NAME}") + # not using add_example_executable() to add target, since we don't want this to have + # to be included in "make all/install/check" + add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) + endforeach() + + target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + + set(COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template 
-Wno-float-equal) + # list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + + target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) + +add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS}) + diff --git a/example/ck_tile/14_moe_smoothquant/README.md b/example/ck_tile/14_moe_smoothquant/README.md new file mode 100644 index 0000000000000000000000000000000000000000..599b4c348966db80cc9c0195abfcd93f5aaf6b3d --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/README.md @@ -0,0 +1,15 @@ +# moe-smoothquant + +This folder contains example for moe-smoothquant using ck_tile tile-programming implementation. +![](misc/moe-sm.png) + +Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk` + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_smoothquant -j +``` +This will result in an executable `build/bin/tile_example_moe_smoothquant` diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39481e2c83c8129d0b8a21de3f9183e97bd932aa --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp @@ -0,0 +1,27 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +#endif + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6feccbdaff3959fc246bb88cfaf2669bd6f11398 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e2c9366338ed53c07f134d08437b5618a3fd058 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp @@ -0,0 +1,19 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..373cb0352b6b13d7a9089c8099a36c640656701b --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c0c778f36c98d102b3a15bd0d9c14380892626df --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47cffd5fc2f067e176ed77e5f59d3f7eef113893 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..726d6018a6bf99073a878ffe702f1a20f765b9eb --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6026d509d0c28825a82b571deb68184f2d6501ac --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3924662fe530c27b89e173edd282939b0531d0bd --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00d5c980d7ab0eed32d4c8582c5c09a6b74b8ac6 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c908739efa57f00848e9f14e402db130b329d89b --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp @@ -0,0 +1,27 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +#endif + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65e9470cdeb62723c8e2114691947e9b924ab078 
--- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..421352f45ffbc3fa30ac92b1347d928417150138 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f102cb6d60c60a46e94bc423394aa6f86b9c35e8 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad7b9e3d158641b2b91458c3aea32e1479ae4999 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb79ec7ab4222c07cc0f27c08b756a90de7b95d4 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..766c60689f7884341c56d4336f56722d9012548d --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c24e1ebe014d51aa519853a2f6430da7327d1ce --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..df785eefeff22b0fcbf2208c36c7ea5f592d5484 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d89f1c3bbf625da28c12f9b8f28c22a1fa47d4e4 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp @@ -0,0 +1,16 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "moe_smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d86c54b1ad7ad538c0da2de8816fa78ef705185 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "moe_smoothquant.hpp" + +template +using trait_ = moe_smoothquant_traits_; + +template +float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/, + moe_smoothquant_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd 2p + if(a.hidden_size <= 64) { + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 128) { + if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 256) { + if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 512) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 768) { + if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 ==
0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 1024) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 1536) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 2048) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 3072) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size <= 4096) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + else if(a.hidden_size > 4096) { + if (a.hidden_size % 8 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 4 == 0) + r = moe_smoothquant_>(s, a); + else if (a.hidden_size % 2 == 0) + r = moe_smoothquant_>(s, a); + else + r = moe_smoothquant_>(s, a); + } + return r; + // clang-format on +} + +float moe_smoothquant(moe_smoothquant_traits t, + moe_smoothquant_args a, + const ck_tile::stream_config& s) +{ + if(t.in_type.compare("fp16") == 0 && t.out_type == "int8") + { + return moe_smoothquant_dispatch(t, a, s); + } + else if(t.in_type.compare("fp16") == 0
&& t.out_type == "fp8") + { + return moe_smoothquant_dispatch(t, a, s); + } + else if(t.in_type.compare("bf16") == 0 && t.out_type == "int8") + { + return moe_smoothquant_dispatch(t, a, s); + } + else if(t.in_type.compare("bf16") == 0 && t.out_type == "fp8") + { + return moe_smoothquant_dispatch(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..885d9ff7bf5319d763f90983f78d73a46fca5bb2 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -0,0 +1,65 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "moe_smoothquant.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = moe_smoothquant_args; + +template +using trait_ = moe_smoothquant_traits_; + +template +float moe_smoothquant_(const S& s, A a) +{ + using InputType = typename Traits_::InputType; + using OutputType = typename Traits_::OutputType; + + using PipelineProblem = ck_tile::SmoothquantPipelineProblem< + typename MoeSmoothquantTypeConfig::XDataType, + typename MoeSmoothquantTypeConfig::SmoothScaleDataType, + typename MoeSmoothquantTypeConfig::ComputeDataType, + typename MoeSmoothquantTypeConfig::YScaleDataType, + typename MoeSmoothquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::MoeSmoothquant; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto
kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png b/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png new file mode 100644 index 0000000000000000000000000000000000000000..5a40099ef3ce3860ed133e4b150ad4785108f129 Binary files /dev/null and b/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png differ diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc5b397c854fbedaa38d5dbf8f395ab3ecbbb186 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -0,0 +1,276 @@ +#include "ck_tile/host.hpp" +#include "moe_smoothquant.hpp" +#include +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +auto create_args(int argc, char* argv[]) +{ +
ck_tile::ArgParser arg_parser; + arg_parser.insert("t", "3328", "tokens dimension") + .insert("h", "4096", "hidden_size dimension") + .insert("e", "32", "experts") + .insert("k", "5", "topk") + .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "fp16", "input precision, fp16/bf16") + .insert("prec_o", "int8", "precision, int8/fp8") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// Runs one moe_smoothquant case described by the parsed arguments: uploads the inputs, launches the kernel through moe_smoothquant(), prints timing, and (when -v is non-zero) compares yscale and qy against a host reference. Returns true only when every enabled check passed. +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t tokens = arg_parser.get_int("t"); + ck_tile::index_t hidden_size = arg_parser.get_int("h"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = hidden_size; + ck_tile::index_t experts = arg_parser.get_int("e"); + ck_tile::index_t topk = arg_parser.get_int("k"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= hidden_size); + + using TypeConfig = MoeSmoothquantTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor smscale_host({experts * hidden_size}); + ck_tile::HostTensor topk_ids_host({tokens, topk}); + + ck_tile::HostTensor yscale_host_ref({topk * tokens}, {1}); + ck_tile::HostTensor
yscale_host_dev({topk * tokens}, {1}); + + ck_tile::HostTensor qy_host_ref({topk * tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({topk * tokens, hidden_size}, {stride, 1}); + + topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11937); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + smscale_buf.ToDevice(smscale_host.data()); + topk_ids_buf.ToDevice(topk_ids_host.data()); + + std::cout << "[" << prec_i << "-" << prec_o << "]" + << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride + << ", experts:" << experts << ", topk:" << topk << std::flush; + + moe_smoothquant_traits traits{prec_i, prec_o}; + + moe_smoothquant_args args{x_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), + topk_ids_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + tokens, + hidden_size, + experts, + topk, + stride, + stride}; + + float ave_time = moe_smoothquant( + traits, args, ck_tile::stream_config{nullptr, true, kname ?
1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size + + sizeof(SmoothScaleDataType) * topk * hidden_size + + sizeof(YScaleDataType) * topk * tokens + + sizeof(QYDataType) * topk * tokens * hidden_size; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({topk * tokens, hidden_size}, {stride, 1}); + // smooth outlier + { + auto f = [&](auto i_token) { + for(int i_topk = 0; i_topk < topk; i_topk++) + { + auto i_expert = topk_ids_host(i_token, i_topk); + + for(int i_h = 0; i_h < hidden_size; ++i_h) + { + auto v_smscale = ck_tile::type_convert( + smscale_host(i_expert * hidden_size + i_h)); + auto v_x = ck_tile::type_convert(x_host(i_token, i_h)); + // y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale; + y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale; + } + } + }; + + ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({topk * tokens}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == hidden_size) + { + // accumulate with &= so a yscale mismatch found above is not overwritten by this check + pass &=
ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < topk * tokens; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + + hidden_size); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + + hidden_size); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string prec_i = arg_parser.get_str("prec_i"); + const std::string prec_o = arg_parser.get_str("prec_o"); + if(prec_i == "fp16" && prec_o == "int8") + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8") + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1b90b14b2e45e618eff3d3ac9d310ab48dd6d3a --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +template +struct MoeSmoothquantTypeConfig +{ + using XDataType = InputType; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = OutputType; + using ComputeDataType = float; +}; + +// runtime args +struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs +{ +}; + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct moe_smoothquant_traits_ +{ + using InputType = ck_tile::remove_cvref_t; + using OutputType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr
ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a); + +// This is the public API, will be generated by script +struct moe_smoothquant_traits +{ + std::string in_type; // input type + std::string out_type; // output type +}; + +float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/14_moe_smoothquant/script/perf_test.sh b/example/ck_tile/14_moe_smoothquant/script/perf_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..d1e848b930fc0ca91be56a834d42199d5b072df0 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/script/perf_test.sh @@ -0,0 +1,37 @@ + +EXE=build/bin/tile_example_moe_smoothquant + +$EXE -t=1 -h=1 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=80 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=128 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=144 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=168 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=184 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=256 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=288 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=344 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=376 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=448 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=512 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=924 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=1024 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=1078 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=1996 -v=1 -prec=bf16 -repeat=1000 +$EXE -t=700 -h=4080 -v=1 -prec=bf16 -repeat=1000 + +$EXE -t=700 -h=80 -v=1 -prec=fp16 
-repeat=1000 +$EXE -t=700 -h=128 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=144 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=168 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=184 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=256 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=288 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=344 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=376 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=448 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=512 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=924 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=1024 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=1078 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=1996 -v=1 -prec=fp16 -repeat=1000 +$EXE -t=700 -h=4080 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh b/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..e01f3de10a017ea2a9a1fd54bb55d6f9e7a970b5 --- /dev/null +++ b/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh @@ -0,0 +1,32 @@ +#!/bin/sh +EXE=build/bin/tile_example_moe_smoothquant + +for pr_i in "fp16" "bf16" ; do +for pr_o in "int8" "fp8" ; do +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=99 -h=13 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=17 -h=16 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=100 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=4 -h=128 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=80 -h=127 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=22 -h=255 -stride=256 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=599 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=19 -h=512 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=33 -h=313 -stride=1000 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=11 -h=510 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=171 -h=676 -stride=818 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=91 -h=636 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=12 -h=768 -stride=800 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=100 -h=766 -stride=812 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=31 -h=1024 +$EXE 
-prec_i=$pr_i -prec_o=$pr_o -t=64 -h=1000 -stride=1004 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=8 -h=1501 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=1826 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=5 -h=2040 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=2734 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=3182 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=9 -h=4096 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=8192 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=10547 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=17134 +done +done diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a716eef19ec38f930728ab2667de18f2340ffd05 --- /dev/null +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -0,0 +1,19 @@ +set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_EXAPMLE_FUSED_MOE}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) +target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a +list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta +# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) +# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +target_compile_options(${TILE_EXAPMLE_FUSED_MOE} 
PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b6ceabf3515e7861a7c43defacb1a3eb5b626907 --- /dev/null +++ b/example/ck_tile/15_fused_moe/README.md @@ -0,0 +1,72 @@ +# fused-moe +Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance +![](misc/moe-0.png) + +The benefits of this fused-moe: +* 1.5~2x perf boost compared with the current vllm solution +* zero workspace to reduce memory footprint +* far fewer kernel instances, easy to maintain + +# Implementation and feature support +## NOTES: +Currently, in the gate+up fp16 case the accumulator can very easily exceed the fp16 max (65504), resulting in INF. Please use BF16 for the gate+up case; the API side does not check for this. + +## moe-sorting +This is a common pre-processing step before the actual moe-gemm. The purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). Besides, we extend this op to also zero the output buffer (to be used as the reduce buffer with atomics) + +## moe-gemm +`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer.
A naive understanding of fused-moe is the token-by-token view, as in the picture below: +![](misc/moe-1.png) +After `moe-sorting`, we can view this algorithm as expert-by-expert, as below: +![](misc/moe-2.png) + +## optimization +Summary of the key design points of this fused-moe operator: +* fuse 2 group-gemm + activation + `topk-weight` multiply into a single kernel, using atomics for the 2nd gemm accumulation +* fuse buffer-zeroing into `moe-sorting`, so users no longer need to call an extra torch.zero() for the out buffer +* fused scatter-gather for row index (same as vllm) +* pre-shuffle the B matrix (weight) to maximize memory throughput. The input (activation) keeps its original layout `[batch, hidden]`. +* extremely optimized pipeline using block-inline-asm (we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck + +## +``` +// [indexing implementation-1] +// using M_a as constexpr block_size to partition all tokens into different slices +// each slice maps to one expert, and one expert can have multiple slices +// e.g.
num_experts = 6, topk=3, M_a = 4, input_tokens = 5 +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers) +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +// +// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// * this could be larger than actual, since actual tokens are on GPU +// +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] +// * length is (max_num_tokens_padded + block_size - 1) / block_size +// +// num_tokens_post_padded_ptr : [28] +// num_sorted_tiles_ptr : [7] +// +// * different from vLLM +// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id +// 2) need sorted_weight_ptr +// 3) use num_sorted_tiles_ptr, already divided by M_a +// +// * below used for indexing +// 1) sorted_token_ids_ptr [max_num_tokens_padded] +// 2) sorted_weight_ptr +// 3) sorted_expert_ids_ptr +// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) +// +// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1) +``` \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c4e7b09cab9d7d7cb36f394de828a92815b672d --- /dev/null +++
b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "fused_moesorting.hpp" +#include "fused_moegemm.hpp" + +struct fused_moe_args +{ + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + void* o_ptr; // [m, k], output token (no need to do zeroing) + + const void* topk_ids_ptr; // [tokens, topk] + const void* topk_weight_ptr; // [tokens, topk] + void* sorted_token_ids_ptr; // [max_num_tokens_padded] + void* sorted_weight_ptr; // [max_num_tokens_padded] + void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size] + void* num_sorted_tiles_ptr; // [1] + + ck_tile::index_t block_m; // block_m, used to devide the input + ck_tile::index_t hidden_size; // k + ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value + ck_tile::index_t num_tokens; // input number of tokens for current iteration + ck_tile::index_t num_experts; // number of groups + ck_tile::index_t topk; // need this? 
+ + ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size +}; + +// This is the public API, will be generated by script +struct fused_moe_traits +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moegemm.hpp b/example/ck_tile/15_fused_moe/fused_moegemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a1027c80cdea639797613c95c74536fdac2acf0 --- /dev/null +++ b/example/ck_tile/15_fused_moe/fused_moegemm.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fused_moe.hpp" +#include + +// this is only a convenient structure for creating an example +// this is not part of the host API +template +struct FusedMoeGemmTypeConfig; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using GDataType = ck_tile::bf16_t; + using DDataType = ck_tile::bf16_t; + using AccDataType = float; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::fp16_t; + using GDataType = ck_tile::fp16_t; + using DDataType = ck_tile::fp16_t; + using AccDataType = float; + using ODataType = ck_tile::fp16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using GDataType = ck_tile::int8_t; + using DDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +// runtime args +struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs +{ +}; + +// This is the public API, will 
be generated by script +struct fused_moegemm_traits +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57dace9b41fff4cf87d7faf42ea19fdfd9a06d26 --- /dev/null +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +struct fused_moesorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float +}; + +struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d29e4fd4fd619b002a24e75483a8e407a42d6815 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "fused_moe.hpp" + +float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) +{ + auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; + + auto o_data_bytes = [&]() { + if(t.prec_o == "fp32") + return 4; + else if(t.prec_o == "fp16" || t.prec_o == "bf16") + return 2; + else if(t.prec_o == "int8" || t.prec_o == "fp8") + return 1; + return 1; + }(); + + auto t0 = fused_moesorting_trait{"int32", "fp32"}; + auto a0 = fused_moesorting_args{ + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + }; + + auto t1 = fused_moegemm_traits{t.prec_i, + t.prec_w, + t.prec_o, + t.prec_st, + t.prec_sw, + t.prec_sq, + t.prec_kw, + t.block_m, + t.activation, + t.gate_only, + t.fused_quant}; + auto a1 = fused_moegemm_args{ + a.a_ptr, // const void* a_ptr; + a.a_scale_ptr, // const void* a_scale_ptr; + a.g_ptr, // const void* g_ptr; + a.d_ptr, // const void* d_ptr; + a.g_scale_ptr, // const void* g_scale_ptr; + a.d_scale_ptr, // const void* d_scale_ptr; + a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; + a.o_ptr, // void* o_ptr; + a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; + a.sorted_weight_ptr, // const void* sorted_weight_ptr; + a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; + a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; + a.hidden_size, // index_t hidden_size; + a.intermediate_size, // index_t intermediate_size; + a.num_tokens, // index_t num_tokens; + a.num_experts, 
// index_t num_experts; + a.topk, // index_t topk; + a.stride_token // index_t stride_token; + }; + + float r0 = -1; + float r1 = -1; + + float r = ck_tile::launch_kernel( + s, + [=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); }, + [=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); }); + + // keep unsupported case return negative + if(r0 < 0 || r1 < 0) + return -1; + + return r; +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49d29bad51c415dc5f0feb288fd4d5bd939f8b30 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a); + +template +using S = ck_tile::sequence; + +float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s) +{ + // clang-format off + float r = -1; + if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr 
ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) 
+ { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + // clang-format on + return r; +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..343ddbed13ab4873b0120e0fb7fdeff8d8bfcc35 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "fused_moegemm_api_traits.hpp" +#include "ck_tile/ops/fused_moe.hpp" +#include + +template +using S = ck_tile::sequence; + +// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) +{ + using f_traits = ck_tile::FusedMoeGemmTraits; + using f_shape = ck_tile::FusedMoeGemmShape; + + constexpr auto get_activation_ = []() { + if constexpr(Ts_::Activation == 0) + { + return ck_tile::element_wise::FastGeluAsm{}; + } + else + return ck_tile::element_wise::Silu{}; + }; + using f_act_ = ck_tile::remove_cvref_t; + + using f_problem = ck_tile::FusedMoeGemmPipelineProblem; + + // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; + using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk; + using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear; + using f_kernel = ck_tile::FusedMoeGemmKernel; + + const dim3 grids = f_kernel::GridSize(a); + constexpr dim3 blocks = f_kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + static int printed = 0; + + auto kargs = f_kernel::MakeKargs(a); + if(s.log_level_ > 0 && printed == 0) + { + std::cout << ", " << f_kernel::GetName() << std::flush; + printed = 1; + } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a7e53cc6548f2d3ddf3590b40c709549b7cf8d4d --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template + typename WarpPerBlock_, + typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu + ck_tile::index_t GateOnly_ = 0, + ck_tile::index_t FusedQuant_ = 0> +struct fmoe_ // traits, ugly name, only used for internal +{ + using TypeConfig = FusedMoeGemmTypeConfig; + + using ADataType = ck_tile::remove_cvref_t; + using GDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token + static constexpr ck_tile::index_t BI_ = + BlockTIle_::at(ck_tile::number<1>{}); // block intermediate + static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden + static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down + + using BlockTile_0 = ck_tile::sequence; + using WarpPerBlock_0 = ck_tile::remove_cvref_t; + using WarpTile_0 = ck_tile::remove_cvref_t; + + using BlockTile_1 = ck_tile::sequence; + using WarpPerBlock_1 = ck_tile::remove_cvref_t; + using WarpTile_1 = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu + static constexpr ck_tile::index_t GateOnly = GateOnly_; + static constexpr ck_tile::index_t FusedQuant = FusedQuant_; +}; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp 
b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5691743565ca5ffd67fbaccf3e227310b7c7f143 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" +#include "fused_moegemm_api_internal.hpp" + +// clang-format off +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); +// clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74632df415ab28e0554b5f21c98a2b293c27e107 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" +#include "fused_moegemm_api_internal.hpp" + +// clang-format off +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +// clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ca24c5c9a2ce230640e30efa6b8b2ec42f445c4 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "fused_moesorting.hpp" + +#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr ck_tile::index_t expert_tile = expert_tile_; \ + using ms_problem = \ + ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH(unroll_num_) \ + if(a.num_experts <= 8) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \ + } \ + else if(a.num_experts <= 16) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \ + } \ + else if(a.num_experts <= 32) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \ + } \ + else if(a.num_experts <= 64) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ + } + +float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + if(a.num_experts > 127) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(10): { + 
MOE_SORTING_DISPATCH(10); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } + } + return -1; +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..51611a67bc6558ad1a990858c4bb8fb22fd46f23 --- /dev/null +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -0,0 +1,606 @@ +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "fused_moe.hpp" + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +// mfma_type, 0:32x32, 1:16x16 +// TODO: padding? +template +auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 3); + int b_ = t.get_lengths()[0]; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[2]; + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 64, 4, 16}); + 
std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + return t; +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("t", "128", "num input tokens") + .insert("e", "32", "num of experts") + .insert("k", "5", "topk") + .insert("h", "8192", "hidden_size of this model") + .insert("i", "8192", "intermediate_size between 2 gemms of FFN") + .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") + .insert("bm", "32", "blocking factor for sorted tokens") + .insert("tp", "8", "tensor parallel size") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "bf16", "input precision") + .insert("prec_w", "bf16", "weight precision") + .insert("prec_o", "bf16", "output precision") + .insert("prec_st", "auto", "token scale data type. auto will set to fp32") + .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") + .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") + .insert("prec_kw", "auto", "topk-weight data type. 
auto will set to fp32") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert( + "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") + .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") + .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu") + .insert("balance", + "0", + "if set to 1, will try balance the expert in topk-ids(convenient for testing)") + .insert("init", + "1", + "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand " + "normalized[0, 1]" + "normalized(slow)") + .insert("seed", "11939", "seed used to do random") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, +// SQ:smooth-quant-type, KW:topk-weight-type +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t tokens = arg_parser.get_int("t"); + ck_tile::index_t experts = arg_parser.get_int("e"); + ck_tile::index_t topk = arg_parser.get_int("k"); + ck_tile::index_t hidden_size = arg_parser.get_int("h"); + ck_tile::index_t intermediate_size = arg_parser.get_int("i"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + ck_tile::index_t block_m = arg_parser.get_int("bm"); + ck_tile::index_t activation = arg_parser.get_int("act"); + if(stride < 0) + stride = hidden_size; + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? 
"fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int api = arg_parser.get_int("api"); + int balance = arg_parser.get_int("balance"); + int tp = arg_parser.get_int("tp"); + int init = arg_parser.get_int("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + // w0 (Gate+Up or Gate only, N size) + ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; + // w1 (Down, N size) + ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp; + + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_w) + base_str += "x" + prec_w; + if(prec_i != prec_o) + base_str += "=" + prec_o; + if(fused_quant != 0) + { + base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")"; + } + return base_str; + }(); + auto api_str = [&]() { + if(api == 0) + return std::string("fmoe"); + else if(api == 1) + return std::string("moeg"); + else if(api == 2) + return std::string("moes"); + return std::string(""); + }(); + + auto stride_str = [&]() { + if(stride == hidden_size) + return std::string(""); + else + return std::string(", st:") + std::to_string(stride); + }(); + + std::cout + << "[" << api_str << "|" << prec_str << "]" + << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str + << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp + << ", act:" + << activation + // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 + << (gate_only ? 
", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush; + + using TypeConfig = FusedMoeGemmTypeConfig; + using ADataType = typename TypeConfig::ADataType; + using GDataType = typename TypeConfig::GDataType; + using DDataType = typename TypeConfig::DDataType; + using AccDataType = typename TypeConfig::AccDataType; + using ODataType = typename TypeConfig::ODataType; + using AScaleDataType = typename TypeConfig::AScaleDataType; + using GScaleDataType = typename TypeConfig::GScaleDataType; + using DScaleDataType = typename TypeConfig::DScaleDataType; + using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; + using TopkWeightDataType = typename TypeConfig::TopkWeightDataType; + using IndexDataType = typename TypeConfig::IndexDataType; + + // host verify + ck_tile::HostTensor a_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor g_host({experts, shared_intermediate_size_0, hidden_size}); + ck_tile::HostTensor d_host({experts, hidden_size, shared_intermediate_size_1}); + ck_tile::HostTensor o_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor sa_host({tokens}); + ck_tile::HostTensor sg_host({shared_intermediate_size_0}); + ck_tile::HostTensor sd_host({shared_intermediate_size_1}); + ck_tile::HostTensor sy_host({shared_intermediate_size_1}); // smooth-quant + ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort + ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort + + int max_num_tokens_padded = topk * tokens + experts * block_m - topk; + ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_weight_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_expert_ids_host( + {(max_num_tokens_padded + block_m - 1) / block_m}); + ck_tile::HostTensor num_sorted_tiles_host({1}); + + if(init == 0) + { + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(a_host); + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(g_host); + ck_tile::FillStepRange{.5f, -.5f, -0.01f}(d_host); + 
ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sa_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sg_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sd_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sy_host); + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(topk_weight_host); + } + else if(init == 1) + { + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(g_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(d_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}( + topk_weight_host); + } + else if(init == 2) + { + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(g_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(d_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sa_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sg_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sd_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sy_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(topk_weight_host); + } + + // permute weight + ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + + // do moe sorting + if(balance) + { + int e_cnt = 0; + for(int i = 0; i < static_cast(topk_ids_host.mData.size()); i++) + { + topk_ids_host.mData[i] = e_cnt; + e_cnt++; + if(e_cnt >= experts) + e_cnt = 0; + } + } + else + { + topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11913); + } + +// leave it here for future debug purpose +#if 0 + a_host.loadtxt("../../ater/input_torch.txt"); + + 
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int"); + // topk_ids_host.savetxt("topk_ids_2.txt"); + topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + + g_host.loadtxt("../../ater/w1_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + d_host.loadtxt("../../ater/w2_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + + ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; +#endif + +#if 0 + std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl; + std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl; + std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl; + std::cout << "topk_weight_host:" << topk_weight_host << std::endl; + std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl; +#endif + auto cal_tflops = [&](auto ms) { + double flop_gemm_0 = + 2 * static_cast(tokens) * topk * shared_intermediate_size_0 * hidden_size; + double flop_gemm_1 = + 2 * static_cast(tokens) * topk * shared_intermediate_size_1 * hidden_size; + return (flop_gemm_0 + flop_gemm_1) / (static_cast(ms) * 1e-3) / 1e12; + }; + + // TODO: this method we use expert-by-expert view, just for reference + auto cal_tbps = [&](auto ms) { + double token_bytes = + static_cast(tokens) * topk / experts * hidden_size * sizeof(ADataType); + double w0_bytes = static_cast(shared_intermediate_size_0) * experts * hidden_size * + sizeof(GDataType); + double w1_bytes = static_cast(shared_intermediate_size_1) * experts * hidden_size * + sizeof(DDataType); + double o_bytes = + static_cast(tokens) * topk / experts * 
hidden_size * sizeof(ODataType); + double topk_weights_bytes = static_cast(tokens) * topk * sizeof(TopkWeightDataType); + // ignore index, they are too small + + return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) / + (static_cast(ms) * 1e-3) / 1e12; + }; + + if(api == 0) + { + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem g_perm_buf(g_perm_host); + ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem sa_buf(sa_host); + ck_tile::DeviceMem sg_buf(sg_host); + ck_tile::DeviceMem sd_buf(sd_host); + ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem topk_ids_buf(topk_ids_host); + ck_tile::DeviceMem topk_weight_buf(topk_weight_host); + + ck_tile::DeviceMem sorted_token_ids_buf( + sorted_token_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_buf( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem num_sorted_tiles_buf( + num_sorted_tiles_host.get_element_space_size_in_bytes()); + + fused_moe_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + activation, + gate_only, + fused_quant}; + + fused_moe_args args{a_buf.GetDeviceBuffer(), + fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, + g_perm_buf.GetDeviceBuffer(), + d_perm_buf.GetDeviceBuffer(), + fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? 
sy_buf.GetDeviceBuffer() : nullptr, + o_buf.GetDeviceBuffer(), + topk_ids_buf.GetDeviceBuffer(), + topk_weight_buf.GetDeviceBuffer(), + sorted_token_ids_buf.GetDeviceBuffer(), + sorted_weight_buf.GetDeviceBuffer(), + sorted_expert_ids_buf.GetDeviceBuffer(), + num_sorted_tiles_buf.GetDeviceBuffer(), + block_m, + hidden_size, + intermediate_size / tp, + tokens, + experts, + topk, + stride}; + float ave_time = fused_moe( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + // float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " + << cal_tbps(ave_time) << " TB/s" << std::flush; + bool pass = true; + +#define CPU_FUSED_MOE(act_type_) \ + ck_tile::reference_fused_moe(a_host, \ + g_host, \ + d_host, \ + sa_host, \ + sg_host, \ + sd_host, \ + sy_host, \ + o_host, \ + sorted_token_ids_host, \ + sorted_weight_host, \ + sorted_expert_ids_host, \ + num_sorted_tiles_host, \ + topk_ids_host, \ + block_m, \ + tokens, \ + experts, \ + hidden_size, \ + intermediate_size / tp, \ + topk, \ + gate_only) + + if(do_validation) + { + ck_tile::reference_moe_sorting( + topk_ids_host, + topk_weight_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host.mData[0], + experts, + block_m); + if(activation == 0) + { + CPU_FUSED_MOE(ck_tile::element_wise::Gelu); + } + else + { + CPU_FUSED_MOE(ck_tile::element_wise::Silu); + } + + auto o_dev = o_buf.ToHost(); + // o_dev.savetxt("gpu-out.txt", "float"); + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err( + o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush; + } + std::cout << std::flush << std::endl; + return pass; + } + else if(api == 1) + { + ck_tile::reference_moe_sorting( + topk_ids_host, + topk_weight_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host.mData[0], + experts, + block_m); + + // done, preparing GPU buffer + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem g_perm_buf(g_perm_host); + ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem sa_buf(sa_host); + ck_tile::DeviceMem sg_buf(sg_host); + ck_tile::DeviceMem sd_buf(sd_host); + ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem o_buf(o_host); + + // manually clear output buffer for atomic + o_buf.SetZero(); + // + + ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host); + ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); + ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); + ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); + + fused_moegemm_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + activation, + gate_only, + fused_quant}; + + fused_moegemm_args args{a_buf.GetDeviceBuffer(), + fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, + g_perm_buf.GetDeviceBuffer(), + d_perm_buf.GetDeviceBuffer(), + fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, + o_buf.GetDeviceBuffer(), + sorted_token_ids_buf.GetDeviceBuffer(), + sorted_weight_buf.GetDeviceBuffer(), + sorted_expert_ids_buf.GetDeviceBuffer(), + num_sorted_tiles_buf.GetDeviceBuffer(), + hidden_size, + intermediate_size / tp, + tokens, + experts, + topk, + stride}; + + float ave_time = fused_moegemm( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" 
<< std::endl << std::flush; + return false; + } + + // float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " + << cal_tbps(ave_time) << " TB/s" << std::flush; + bool pass = true; + + if(do_validation) + { + if(activation == 0) + { + CPU_FUSED_MOE(ck_tile::element_wise::Gelu); + } + else + { + CPU_FUSED_MOE(ck_tile::element_wise::Silu); + } + + auto o_dev = o_buf.ToHost(); + // o_dev.savetxt("gpu-out.txt", "float"); + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err( + o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + std::cout << std::flush << std::endl; + + return pass; + } + return false; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + + // no dynamic quant case + if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 0 + : -2; + } + else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 
0 + : -2; + } + + return -3; +} diff --git a/example/ck_tile/15_fused_moe/misc/moe-0.png b/example/ck_tile/15_fused_moe/misc/moe-0.png new file mode 100644 index 0000000000000000000000000000000000000000..aed1964f2802c4e7f65d7080f338309c8c2287a6 Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-0.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-1.png b/example/ck_tile/15_fused_moe/misc/moe-1.png new file mode 100644 index 0000000000000000000000000000000000000000..91a1f2d9dde2eb892ab621bb1fdaa9e1f7f23a8a Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-1.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-2.png b/example/ck_tile/15_fused_moe/misc/moe-2.png new file mode 100644 index 0000000000000000000000000000000000000000..98d83866fad9925583db583e5179f139202cf612 Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-2.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-3.png b/example/ck_tile/15_fused_moe/misc/moe-3.png new file mode 100644 index 0000000000000000000000000000000000000000..77c6d9b6e43ea2c2ef9087eadff6028b6af3f113 Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-3.png differ diff --git a/example/ck_tile/16_batched_gemm/CMakeLists.txt b/example/ck_tile/16_batched_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..78e78c6b0458fe034d2d4be936a4275124edb8c5 --- /dev/null +++ b/example/ck_tile/16_batched_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp) diff --git a/example/ck_tile/16_batched_gemm/README.md b/example/ck_tile/16_batched_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..34b56db526b10c183d023f763c5172682c6d36b6 --- /dev/null +++ b/example/ck_tile/16_batched_gemm/README.md @@ -0,0 +1,37 @@ +# Batched GEMM + +This folder contains example for batched GEMM using ck_tile tile-programming implementation. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +make tile_example_batched_gemm -j +``` +This will result in an executable `build/bin/tile_example_batched_gemm` + +## example +``` +args: + -m m dimension (default:256) + -n n dimension (default:128) + -k k dimension (default:128) + -a_layout A tensor data layout (default:R) (R for Row, C for Col) + -b_layout B tensor data layout (default:R) (R for Row, C for Col) + -c_layout C tensor data layout (default:R) (R for Row, C for Col) + -stride_a Tensor A stride (default:128) + -stride_b Tensor B stride (default:128) + -stride_c Tensor C stride (default:128) + -batch_stride_a Batch A stride (default:32768) + -batch_stride_b Batch B stride (default:16384) + -batch_stride_c Batch C stride (default:32768) + -batch_count Batch count (default:16) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` \ No newline at end of file diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..949621e116dafd504e3849628e3bf3fcbde33483 --- /dev/null +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "batched_gemm.hpp" + +template +float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::BatchedGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b7e22160a2e24c5c3e71ca53028e8c9c393a7e5 --- /dev/null +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" + +template +struct BatchedGemmTypeConfig; + +template <> +struct BatchedGemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +using Types = BatchedGemmTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "128", "n dimension") + .insert("k", "128", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("a_layout", "R", "A tensor data layout - Row by default") + 
.insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("batch_stride_a", "32768", "Batch A stride") + .insert("batch_stride_b", "16384", "Batch B stride") + .insert("batch_stride_c", "32768", "Batch C stride") + .insert("batch_count", "16", "Batch count") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..d0df8845cc37ee8c29efccb146338a4cea3fd61b --- /dev/null +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t batch_stride_A, + ck_tile::index_t batch_stride_B, + ck_tile::index_t batch_stride_C, + ck_tile::index_t batch_count, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::BatchedGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.batch_stride_A = batch_stride_A; + args.batch_stride_B = batch_stride_B; + args.batch_stride_C = batch_stride_C; + args.batch_count = batch_count; + + float ave_time = batched_gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Batched Gemm"}; + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_byte = 
sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * N * K + + sizeof(CDataType) * batch_count * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B + << " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : " + << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_batched_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a"); + ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); + ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); + ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if 
constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({batch_count_, row, col}, + {batch_stride, stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({batch_count_, row, col}, + {batch_stride, 1_uz, stride}); + } + }; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, a_layout); + stride_B = f_get_default_stride(K, N, stride_B, b_layout); + stride_C = f_get_default_stride(M, N, stride_C, c_layout); + + ck_tile::HostTensor a_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout)); + ck_tile::HostTensor b_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout)); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout)); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_batched_gemm(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + 
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + const auto b_n_k = b_k_n.transpose({0, 2, 1}); + + ck_tile::reference_batched_gemm( + a_m_k, b_n_k, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::HostTensor c_m_n_gpu_ref( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + batch_count * M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + batch_count * N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_batched_gemm_gpu(d_A, + d_B, + d_C, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count); + + 
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + batch_count * M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_batched_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + // if(a_layout == "R" && b_layout == "R") + // { + // return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + if(a_layout == "R" && b_layout == "C") + { + return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not + // work else if(a_layout == "C" && b_layout == "C") + // { + // return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + // } + // else if(a_layout == "C" && b_layout == "R") + // { + // return run_batched_gemm_example_with_layouts(argc, 
argv, Col{}, Row{}, Row{}); + // } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d34013dd6c926a9d95c66ebaed5b1849a2484114 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) + diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1a0458eda6bbbc0bacb8c83676d038fc829c7ea --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -0,0 +1,25 @@ +# Grouped CShuffle GEMM + +This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the gemm calculation +make tile_example_grouped_gemm -j +``` +This will result in an executable `build/bin/tile_example_grouped_gemm` + +## example +``` +args: + -a_layout Tensor A layout (default:R) + -b_layout Tensor B layout (default:R) + -c_layout Tensor C layout (default:R) + -v 0. No validation, 1. 
Validation on CPU + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) +``` diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c32fac6c0d751ac8a9c48daca4312585d33073ac --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm.hpp" + +namespace { + +struct GroupedGemmKernelParam +{ + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + + static const int kBlockPerCu = 1; + static const ck_tile::index_t M_Tile = 128; + static const ck_tile::index_t N_Tile = 128; + static const ck_tile::index_t K_Tile = 32; + + static const ck_tile::index_t M_Warp = 2; + static const ck_tile::index_t N_Warp = 2; + static const ck_tile::index_t K_Warp = 1; + + static const ck_tile::index_t M_Warp_Tile = 32; + static const ck_tile::index_t N_Warp_Tile = 32; + static const ck_tile::index_t K_Warp_Tile = 8; +}; + +using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + +using TilePartitioner = ck_tile::GemmTile1DPartitioner; + +template +using CodegenGemmTraits = ck_tile::TileGemmTraits; + +template +using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem>; + +template +using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1>; + +template +using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemmKernelParam::M_Warp, + 
GroupedGemmKernelParam::N_Warp, + GroupedGemmKernelParam::M_Warp_Tile, + GroupedGemmKernelParam::N_Warp_Tile, + GroupedGemmKernelParam::K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; + +template +using Kernel = ck_tile::GroupedGemmKernel, + GemmEpilogue>; +}; // namespace + +std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return ::Kernel::GetWorkSpaceSize(gemm_descs); +} + +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* p_workspace_) +{ + using GroupedGemmKernel = ::Kernel; + + auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs); + + const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs); + constexpr dim3 blocks = GroupedGemmKernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream( + p_workspace_, + arguments.data(), + arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + GroupedGemmKernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; +} + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ffef95196c79bb16620d9f45364ddc53e00fe02 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using CDataType = ck_tile::half_t; + using AccDataType = float; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. 
Validation on CPU.") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "16", "group count."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +std::size_t get_workspace_size(const std::vector& gemm_descs); + +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* p_workspace_); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..b0a3e9973c247cd5306d941461838e82decf63ae --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + + 
ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = grouped_gemm( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + + std::string op_name{"Grouped Gemm"}; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * args[j].M * args[j].N * args[j].K; + + num_btype += sizeof(ADataType) * args[j].M * args[j].K + + sizeof(BDataType) * args[j].K * args[j].N + + sizeof(CDataType) * args[j].M * args[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + return ave_time; +} + +template +int run_grouped_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + if(!result) + { + return -1; + }; + + auto valid_input_data = [&](int group_count, const auto&... args) { + return !(args.empty() || ...) && group_count == (args.size() == ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + + if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) + { + std::cout << "Please check the input data. Default values will be used." 
<< std::endl; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + + std::cout << "gemm[" << i << "]" + << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc + << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + 
b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + } + + invoke_gemm(warmup, repeat, group_count, gemm_descs); + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(arg_parser.get_int("validate")) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + ck_tile::reference_gemm( + a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU veification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + + return pass; +} + +int run_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + // else if(a_layout == "R" && b_layout == "R") + // { + // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a08fcebb74f3828bc28f5d7cab71c4bad9344711 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -0,0 +1,9 @@ +set(TARGET_NAME tile_example_batched_transpose) +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..d0583e75292c765f3eb8d81a41b72a9394a63105 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/README.md @@ -0,0 +1,27 @@ +# Batched Transpose +This folder contains an example of batched transpose using the ck_tile tile-programming implementation. Currently, it supports batched transpose from NCHW to NHWC or from NHWC to NCHW. So in this way from NCHW you could transpose to either NHWC or NWCH (two transposes). For now the transpose reads a single data point at a time; a vectorized transpose will be added soon. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the transpose executable +make tile_example_batched_transpose -j +``` +This will result in an executable `build/bin/tile_example_batched_transpose` + +## example +``` +args: + -N input batch size (default:2) + -C input channel size. (default:16) + -H input height size. (default:1) + -W input width size. (default:16) + -v whether to do CPU validation or not (default: 1) + -layout_in input tensor data layout - NCHW by default + -layout_out output tensor data layout - NHWC by default + -seed seed to be used, -1 means random every time (default:-1) + -kname set to 1 will print kernel name (default:0) +``` \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77d768fe3fa67b6709c3f1edb05a734debdb79ab --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+#include "batched_transpose_example.hpp" +#include + +template +float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) +{ + uint32_t dim_block_h = (a.height + block_y - 1) / block_y; + uint32_t dim_block_w = (a.width + block_x - 1) / block_x; + uint32_t dim_stride = a.height * a.width; + + a.dim_stride = dim_stride; + a.dim_block_h = dim_block_h; + a.dim_block_w = dim_block_w; + + using block_tile = ck_tile::sequence; + using warp_tile = ck_tile::sequence; + using thread_tile = ck_tile::sequence; + + using ts_problem = + ck_tile::BatchedTransposeProblem; + using ts_pipeline = ck_tile::BatchedTransposePipeline; + + using kernel = ck_tile::BatchedTransposeKernel; + + auto kargs = kernel::MakeKargs(a); + + const dim3 grids = kernel::GridSize(a); + constexpr dim3 blocks = kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \ + F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \ + F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \ + F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1) + +// Macro that defines one static function per line +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \ + static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ + } + +FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s) +{ + if(t.type == "fp16") + { + return transpose_fn_fp16_16_16_8_8_1_1(a, s); + } + else if(t.type == "bf16") + { + return transpose_fn_bf16_16_16_8_8_1_1(a, s); + } + else if(t.type == "fp32") + { + return transpose_fn_fp32_16_16_8_8_1_1(a, s); + } 
+ else if(t.type == "int8") + { + return transpose_fn_int8_16_16_8_8_1_1(a, s); + } + return -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48fc2859bfb4eaadcb9257c69f5fe069b9640cb9 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "batched_transpose_example.hpp" + +#if 0 +template +void dump_host_tensor_4d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 4); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + std::cout << j << ": ["; + for(size_t k = 0; k < len[2]; k++) + { + std::cout << k << ": ["; + for(size_t v = 0; v < len[3]; v++) + { + if constexpr(std::is_same_v) + { + auto m = + ck_tile::type_convert(x(std::vector{i, j, k, v})); + + std::cout << m; + if(v != len[3] - 1) + std::cout << ","; + } + else + { + std::cout << x(std::vector{i, j, k, v}) << " "; + } + } + std::cout << "]" << std::endl; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned 
max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether do CPU validation or not") + .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") + .insert("N", "2", "input batch size. ") + .insert("C", "16", "input channel size.") + .insert("H", "1", "input height size.") + .insert("W", "16", "input width size. ") + .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") + .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "t to 1 will print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_batched_transpose(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string prec = args.get_str("pr"); + int N = args.get_int("N"); + int C = args.get_int("C"); + int H = args.get_int("H"); + int W = args.get_int("W"); + std::string layout_in = args.get_str("layout_in"); + std::string layout_out = args.get_str("layout_out"); + int seed = args.get_int("seed"); + + int dim_in[4], dim_out[4]; + int stride_dim_in[4], stride_dim_out[4]; + bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; + bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; + assert(nchw2nhwc != nhwc2nchw); + (void)nhwc2nchw; + + dim_in[0] = N; + dim_in[1] = nchw2nhwc ? C : H; + dim_in[2] = nchw2nhwc ? H : W; + dim_in[3] = nchw2nhwc ? W : C; + dim_out[0] = N; + dim_out[1] = nchw2nhwc ? H : C; + dim_out[2] = nchw2nhwc ? W : H; + dim_out[3] = nchw2nhwc ? 
C : W; + stride_dim_in[0] = C * H * W; + stride_dim_in[1] = nchw2nhwc ? H * W : C * W; + stride_dim_in[2] = nchw2nhwc ? W : C; + stride_dim_in[3] = 1; + stride_dim_out[0] = C * H * W; + stride_dim_out[1] = nchw2nhwc ? C * W : H * W; + stride_dim_out[2] = nchw2nhwc ? C : W; + stride_dim_out[3] = 1; + + if(seed < 0) + { + seed = std::time(nullptr); + } + + ck_tile::HostTensor x_host( + {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, + {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); + ck_tile::HostTensor y_host( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + auto trait = batched_transpose_trait{prec, layout_in}; + + uint32_t height = nchw2nhwc ? C : H * W; + uint32_t width = nchw2nhwc ? 
H * W : C; + + batched_transpose_kargs karg = [&]() { + batched_transpose_kargs a_; + a_.p_input = x_dev.GetDeviceBuffer(); + a_.p_output = y_dev.GetDeviceBuffer(); + a_.batch = N; + a_.height = height; + a_.width = width; + return a_; + }(); + + ck_tile::stream_config sc{nullptr, true}; + + auto ms = batched_transpose(trait, karg, sc); + + std::size_t num_operations = N * C * H * (W - 1); + std::size_t num_bytes = N * C * H * W * sizeof(Type); + + float ave_time = ms * 1E-3; + float gb_per_sec = num_bytes / ms * 1.E-6; + float tflops = static_cast(num_operations) / ms * 1.E-6; + + std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H + << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out + << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" + << gb_per_sec << " GB/s, " << std::endl; + + printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", + prec.c_str(), + N, + C, + H, + W, + layout_in.c_str(), + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + + if(ms < 0) + { + return false; + } + + y_dev.FromDevice(y_host.data()); + + bool rtn = true; + if(validate) + { + // this host buffer will not copy to GPU, so no need use stride + ck_tile::HostTensor y_ref( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); + + auto [rtol, atol] = get_elimit(""); + + rtn &= ck_tile::check_err( + y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); + } + printf("valid:%s\n", rtn ? 
"y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string prec = args.get_str("pr"); + + bool r = true; + if(prec.compare("fp32") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("fp16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("bf16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("int8") == 0) + { + r &= run_batched_transpose(args); + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp new file mode 100644 index 0000000000000000000000000000000000000000..487ddc17b227d8db8448cd8b363037598196f128 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/batched_transpose.hpp" + +#include +#include + +#pragma once + +struct batched_transpose_trait +{ + std::string type; + std::string layout; +}; + +struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs +{ +}; + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s); diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..fdfef2cea8f25fd644619045d95147f991f1e3f7 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_batched_transpose + +for pr in "fp32" "fp16" "int8" ; do +$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +done diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 15db0f46c4ec78e047415a58666653397a2f24f8..7f4ba2ed359503de2aae3227976165e41befd677 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -13,3 +13,8 @@ add_subdirectory(10_rmsnorm2d) add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) add_subdirectory(13_moe_sorting) +add_subdirectory(14_moe_smoothquant) +add_subdirectory(15_fused_moe) +add_subdirectory(16_batched_gemm) +add_subdirectory(17_grouped_gemm) +add_subdirectory(35_batched_transpose) diff --git a/include/ck/README.md b/include/ck/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..92d5a510873685549f46695eeade126aa5c08186 --- /dev/null +++ b/include/ck/README.md @@ -0,0 +1,23 @@ +[Back to the main page](../../README.md) +# Composable Kernel supported operations +## Supported device operations + + + + + + + + +* [GEMM](../../client_example/01_gemm/README.md) +* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md) +* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md) +* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md) + + + + + + + + diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 999eb0229c5ab16f19084a53b1b01266ed7e48c1..1ec0c6bc2338d23cbdb076d097d317cbb11ded52 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/config.h" #include "ck/utility/env.hpp" - +#ifndef CK_CODE_GEN_RTC #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -14,10 +14,12 @@ // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#endif // to do: add various levels of logging with CK_LOG_LEVEL +#ifndef CK_TIME_KERNEL #define CK_TIME_KERNEL 1 +#endif // constant address space for kernel parameter // https://llvm.org/docs/AMDGPUUsage.html#address-spaces @@ -53,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // define general macros for various architectures #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) + defined(__gfx942__) || defined(__gfx950__) #define __gfx9__ #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) @@ -155,9 +157,22 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // LDS direct loads using inline assembly #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 +// set rounding to nearest even as default for bf16 conversions +#define CK_USE_RNE_BF16_CONVERSION 1 + // set rounding to nearest even as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 0 +// set rounding to nearest even as default for f6 conversions +#define CK_USE_SR_F6_CONVERSION 0 + +// set rounding to nearest even as default for f4 conversions +#define CK_USE_SR_F4_CONVERSION 0 + +// shuffle pk_i4 values during conversion to optimize number of binary +// operations +#define CK_USE_PK4_LAYOUT_SHUFFLE 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -230,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // 
workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// denorm test fix, required to work around dissue -#ifndef CK_WORKAROUND_DENORM_FIX -#define CK_WORKAROUND_DENORM_FIX 0 +// denorm test fix, necessary for gfx90a +#ifndef CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // CK_GFX90A_DENORM_WORKAROUND +// Enable only for gfx90a +#if defined(__gfx90a__) +#if CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 1 +#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1 #else -// enable only for gfx90a -#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) -#endif // CK_WORKAROUND_DENORM_FIX +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // gfx90a // set flag to 1 to build deprecated instances #define CK_BUILD_DEPRECATED 1 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 0f0b7bd607424a2468f4b62a6cbf1ae6331ee52f..994e60025d08f2997769e2c1bb85d4e0f1135dfa 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -97,6 +97,10 @@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #endif +#ifndef CK_ENABLE_DPP_KERNELS +#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@ +#endif + // // CK kernels which support XDL (MI series) // @@ -111,6 +115,26 @@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif +#ifndef CK_USE_GFX94 +#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ +#endif + +#ifndef CK_USE_OCP_FP8 +#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@ +#endif + +#ifndef CK_USE_FNUZ_FP8 +#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@ +#endif + +#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH +#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ +#endif + +#ifndef CK_USE_NATIVE_MX_SUPPORT +#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@ +#endif + // clang-format on #endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 
f5c4b43ad2464117d476fed38329b0781f77e160..05dc491af779d828cf38c7510c0f510c19ace8d5 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -55,20 +55,21 @@ inline bool is_xdl_supported() { return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" || + ck::get_device_name() == "gfx950"; } inline bool is_bf16_atomic_supported() { return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_gfx101_supported() diff --git a/library/include/ck/library/utility/algorithm.hpp b/include/ck/library/utility/algorithm.hpp similarity index 100% rename from library/include/ck/library/utility/algorithm.hpp rename to include/ck/library/utility/algorithm.hpp diff --git a/library/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp similarity index 79% rename from library/include/ck/library/utility/check_err.hpp rename to include/ck/library/utility/check_err.hpp index 08bfefb87f072876e5c6cdbb45db72181056dcc9..d33ecaeef8b8f6402b19e25ddeda90e3d2dee422 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -26,6 +26,7 @@ namespace utils { template double get_relative_threshold(const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = 
ck::half_t; using BF16 = ck::bhalf_t; @@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || @@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) template double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; 
using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; @@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -450,5 +452,54 @@ check_err(const Range& out, return res; } +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f4_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 0.5, + double atol = 0.5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int 
err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err + << " number of errors: " << err_count << std::endl; + } + return res; +} + } // namespace utils } // namespace ck diff --git a/library/include/ck/library/utility/conv_common.hpp b/include/ck/library/utility/conv_common.hpp similarity index 100% rename from library/include/ck/library/utility/conv_common.hpp rename to include/ck/library/utility/conv_common.hpp diff --git a/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp similarity index 100% rename from library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp rename to include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp diff --git a/library/include/ck/library/utility/convolution_parameter.hpp b/include/ck/library/utility/convolution_parameter.hpp similarity index 100% rename from library/include/ck/library/utility/convolution_parameter.hpp rename to include/ck/library/utility/convolution_parameter.hpp diff --git a/library/include/ck/library/utility/device_memory.hpp b/include/ck/library/utility/device_memory.hpp similarity index 100% rename from library/include/ck/library/utility/device_memory.hpp rename to include/ck/library/utility/device_memory.hpp diff 
--git a/library/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp similarity index 100% rename from library/include/ck/library/utility/fill.hpp rename to include/ck/library/utility/fill.hpp diff --git a/library/include/ck/library/utility/host_common_util.hpp b/include/ck/library/utility/host_common_util.hpp similarity index 100% rename from library/include/ck/library/utility/host_common_util.hpp rename to include/ck/library/utility/host_common_util.hpp diff --git a/library/include/ck/library/utility/host_gemm.hpp b/include/ck/library/utility/host_gemm.hpp similarity index 100% rename from library/include/ck/library/utility/host_gemm.hpp rename to include/ck/library/utility/host_gemm.hpp diff --git a/library/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp similarity index 86% rename from library/include/ck/library/utility/host_tensor.hpp rename to include/ck/library/utility/host_tensor.hpp index a58acaf11656c91d92c40f952290f85293369455..f1730de0e1f386765e051f0efd0215daae92bde2 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) else os << delim; - if constexpr(std::is_same_v || std::is_same_v) + using RangeType = ck::remove_cvref_t; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) { os << ck::type_convert(v); } + else if constexpr(std::is_same_v) + { + const auto packed_floats = ck::type_convert(v); + const ck::vector_type vector_of_floats{packed_floats}; + os << vector_of_floats.template AsType()[ck::Number<0>{}] << delim + << vector_of_floats.template AsType()[ck::Number<1>{}]; + } else { os << static_cast(v); @@ -266,18 +275,18 @@ struct Tensor using Data = std::vector; template - Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(std::initializer_list lens) : mDesc(lens), mData(GetElementSpaceSize()) { } template Tensor(std::initializer_list lens, std::initializer_list strides) - : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) + : mDesc(lens, strides), mData(GetElementSpaceSize()) { } template - Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize()) { } @@ -287,7 +296,7 @@ struct Tensor { } - Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} + Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} template Tensor CopyAsType() const @@ -322,11 +331,21 @@ struct Tensor std::size_t GetElementSize() const { return mDesc.GetElementSize(); } - std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } + std::size_t GetElementSpaceSize() const + { + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return (mDesc.GetElementSpaceSize() + 1) / 2; + } + else + { + return mDesc.GetElementSpaceSize(); + } + } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } - void SetZero() { ck::ranges::fill(mData, 
0); } + void SetZero() { ck::ranges::fill(mData, T{0}); } template void ForEach_impl(F&& f, std::vector& idx, size_t rank) @@ -469,29 +488,64 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mDesc.GetOffsetFromMultiIndex(is...) / 2; + } + else + { + return mDesc.GetOffsetFromMultiIndex(is...); + } } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } T& operator()(std::vector idx) { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } typename Data::iterator begin() { return mData.begin(); } diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp similarity index 71% rename from library/include/ck/library/utility/host_tensor_generator.hpp rename to include/ck/library/utility/host_tensor_generator.hpp index e87811b76bb8e1c681447e9f248c3fa958446c99..274051da83040f760ac1343babccc5b2ff6f2e76 100644 --- 
a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,7 +37,7 @@ struct GeneratorTensor_1 float value = 1.0; template - ck::bhalf_t operator()(Is...) + ck::half_t operator()(Is...) { return ck::type_convert(value); } @@ -62,13 +62,25 @@ struct GeneratorTensor_1 float value = 1.0; template - ck::bhalf_t operator()(Is...) + ck::f8_t operator()(Is...) { return ck::type_convert(value); } }; #endif +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + template <> struct GeneratorTensor_1 { @@ -81,6 +93,20 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int t = value + 8; + ck::pk_i4_t r = ((t << 4) + t) & 0xff; + return r; + } +}; + template struct GeneratorTensor_2 { @@ -121,6 +147,22 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int hi = std::rand() % (max_value - min_value) + min_value + 8; + int lo = std::rand() % (max_value - min_value) + min_value + 8; + ck::pk_i4_t r = ((hi << 4) + lo) & 0xff; + return r; + } +}; + #if defined CK_ENABLE_FP8 template <> struct GeneratorTensor_2 @@ -153,6 +195,20 @@ struct GeneratorTensor_2 }; #endif +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4_t operator()(Is...) 
+ { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + template struct GeneratorTensor_3 { @@ -223,6 +279,23 @@ struct GeneratorTensor_3 }; #endif +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + template struct GeneratorTensor_4 { @@ -256,14 +329,33 @@ struct GeneratorTensor_Checkboard } }; -template +/** + * @brief Is used to generate sequential values based on the specified dimension. + * + * @tparam T The type of the tensor values. + * @tparam Dim The specific dimension used for generation. + * + * GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor: + * + * 0 1 2 + * 0 1 2 + * 0 1 2 + * + * Essentially, the values generated are logical coordinates of the generated element that + * correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column + * indices. + * + */ +template struct GeneratorTensor_Sequential { template - float operator()(Ts... Xs) const + T operator()(Ts... 
Xs) const { std::array dims = {{static_cast(Xs)...}}; - return dims[Dim]; + + float tmp = dims[Dim]; + return ck::type_convert(tmp); } }; diff --git a/library/include/ck/library/utility/iterator.hpp b/include/ck/library/utility/iterator.hpp similarity index 100% rename from library/include/ck/library/utility/iterator.hpp rename to include/ck/library/utility/iterator.hpp diff --git a/library/include/ck/library/utility/literals.hpp b/include/ck/library/utility/literals.hpp similarity index 100% rename from library/include/ck/library/utility/literals.hpp rename to include/ck/library/utility/literals.hpp diff --git a/library/include/ck/library/utility/numeric.hpp b/include/ck/library/utility/numeric.hpp similarity index 100% rename from library/include/ck/library/utility/numeric.hpp rename to include/ck/library/utility/numeric.hpp diff --git a/library/include/ck/library/utility/ranges.hpp b/include/ck/library/utility/ranges.hpp similarity index 100% rename from library/include/ck/library/utility/ranges.hpp rename to include/ck/library/utility/ranges.hpp diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index d719ef9760d79297600d7524167eba78cd137831..ef2bedd65cefadf8f68a8eefcdb282f742fab563 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr X GetAsType(Idx) const @@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. 
Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr void SetAsType(Idx, X x) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea0c511da37d5af8f263cb044178ea9571d1e22c --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer +}; + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return 
BlockwiseGemmXdlops_pipeline_v4_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4246f4a44e76b4c3f2554fcdfc90019771310f7b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; 
+ + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + 
+ static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename 
vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; 
+ }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 711c47854adad7b2880718f69ec3febe05984bb4..54edf0c3533b1ce753ec5910b305141477feb031 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -269,15 +269,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -341,14 +340,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + 
b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -396,14 +395,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -447,14 +446,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -760,15 +759,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); }); __builtin_amdgcn_sched_barrier(0); // NOTE: Synchronize threads in a workgroup at the start of each MAC @@ -866,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, 
I0, k0, I0), + b_thread_buf); }); __builtin_amdgcn_sched_barrier(0); @@ -942,14 +940,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); }); __builtin_amdgcn_sched_barrier(0); @@ -1018,14 +1016,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); }); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..776f66dbbb8f84c61a5882b92a0f8099db76d1d0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -0,0 +1,1248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + 
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, 
Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 
1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template 
AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using 
Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& 
b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + const BScaleGridDesc& b_scale_grid_desc, + // BScaleThreadCopy + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + ignore = num_loop_per_scale; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + auto c_thread_buf_per_scale = remove_cvref_t(); // need? + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. 
+ if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) + // { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { 
+ // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template 
AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, 
Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if 
constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // 
c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d1be88dd632fa3d878cdd6479d8e7445ceb8217f --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -0,0 +1,530 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ 
static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * 
ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + 
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1); + constexpr auto num_scale_krepeat = KRepeat / num_scale_k_block; + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 2) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == 
TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index bd5a1bedf537c3a4a31d53cae5f2d5ca1beeabb9..e8d10511110056c07c92cc1bf79d1a62c195abde 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -305,14 +305,14 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); }); }); @@ -356,15 +356,14 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, 
Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); }); }); @@ -437,14 +436,14 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); }); }); @@ -496,14 +495,14 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f35c7a97cc323e438ded5e120ed1b5c39a3d3474 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 4 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopUnroll = 2; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr void HotLoopScheduler() + { + // TODO: Take data type into 
consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + 0); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(2 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + 
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_scale_thread_bufs(I0)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(3 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + 
b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + // B scale copy + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(lds_read_reg_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); 
+ + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + 
a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0); 
+ ReadCompFunc(I0, I0, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I1, I1, I0); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp index 1c4de5ed3153f2cc8ca7f4a8ccf57741085fba40..0a0bcbac38c2a732f0d9c2f342797e19c0400845 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2 } template - using is_tuple = decltype(std::declval().IsTuple()); + using is_tuple = decltype(ck::declval().IsTuple()); template __device__ void RunWrite(const DstDescs& dst_descs, diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 0eef827a5b5018efeee94d11f1b768c739d85648..cf20025d46e1ac0ba0de1c529b3fea5d90621915 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif namespace ck { namespace tensor_operation { @@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization Filter3x3, }; +#ifndef CK_CODE_GEN_RTC inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) { switch(s) @@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 908ada016d4f4ae5b36177a91cdd39781080adb4..774982d905fb49551654b574787a5b0fc2bbde81 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,17 +1,51 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#ifndef CK_CODE_GEN_RTC #include #include - +#include +#include #include "ck/stream_config.hpp" +#endif namespace ck { namespace tensor_operation { namespace device { +#ifndef CK_CODE_GEN_RTC +#define GET_OBJECT_NAME_IMLP \ + std::optional GetObjectName() const override \ + { \ + std::string str = __PRETTY_FUNCTION__; \ + static std::regex obj_name_expr{" (.*)::GetObjectName"}; \ + std::smatch match; \ + if(!std::regex_search(str, match, obj_name_expr)) \ + { \ + return str; \ + } \ + return std::string(match[1]) + ';'; \ + } + +#define GET_TEMPLATE_INFO_IMPL \ + std::optional GetTemplateInfo() const override \ + { \ + std::string str = __PRETTY_FUNCTION__; \ + static std::regex template_expr{"\\[(.*)\\]"}; \ + std::smatch match; \ + if(!std::regex_search(str, match, template_expr)) \ + { \ + return std::nullopt; \ + } \ + return std::string(match[1]); \ + } + +#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL +#endif + +#ifndef CK_CODE_GEN_RTC struct BaseArgument { BaseArgument() = default; @@ -36,18 +70,23 @@ struct BaseInvoker virtual ~BaseInvoker() {} }; +#endif struct BaseOperator { BaseOperator() = default; BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; - +#ifndef CK_CODE_GEN_RTC virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } virtual std::string GetTypeIdName() const { return typeid(*this).name(); } + virtual std::optional GetObjectName() const { return std::nullopt; } + + virtual std::optional GetTemplateInfo() const { return std::nullopt; } + virtual std::string GetTypeIdHashCode() const { std::ostringstream oss; @@ -66,7 +105,7 @@ struct BaseOperator assert(p_arg); p_arg->p_workspace_ = p_workspace; } - +#endif virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp 
b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index 6cc2c7bb2f6c176f2d84fdda4be2140db5564360..fcb46082933771ce7f2a24485d30ea54e75f0680 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +struct DeviceBatchedGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t BatchStrideScaleB, + const void* p_b_scale, + ck::index_t Batch, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + template MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp index b2db35b159eb709bb3a12dc6c09d7426aa1a323a..78d8aa997e2b02b05a6aee9f479fde21bbc38ce2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp @@ -36,6 +36,10 @@ struct DeviceGemmV2 : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteA() = 0; + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; }; template MakeInvokerPointer() = 0; }; +template +struct DeviceGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, 
+ const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + const void* p_b_scale, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index 184efbbd68ecea5b3c2b36f52764690e3ad316da..8c9b768a8b9c75ad5da9859c72d53b3d98a5967e 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -13,8 +15,13 @@ namespace ck { namespace tensor_operation { namespace device { +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif /** * \brief Grouped Convolution Forward @@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator static constexpr index_t NumDTensor = DsDataType::Size(); static_assert(NumDTensor == DsLayout::Size(), "wrong! 
Inconsistent NumDTensor"); - +#ifdef CK_CODE_GEN_RTC + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; +#else // If DataType is tuple, user has to pass std::array with pointers. using APointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; using BPointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; +#endif + +#ifndef CK_CODE_GEN_RTC /** * \brief Make argument pointer for grouped conv fwd. @@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator const CDEElementwiseOperation& cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 1e03405536a20a9f050a7943243f1c406594b836..267a970ee5e60a5a644fa22509634588501e0954 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -1,17 +1,87 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include #include +#include +#include #include #include "device_base.hpp" +#include "ck/utility/ignore.hpp" namespace ck { namespace tensor_operation { namespace device { +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. 
+/// +template +struct GroupedGemmKernelArgument +{ + __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + struct GemmDesc { ck::index_t M_, N_, K_; @@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + //--------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. + /// + /// TODO: Add which kernels are using this (TileLoop * FixedNK ??) + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel + /// arguments. + /// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel + /// arguments that should be copied to device memory. 
+ /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + ignore = p_host_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. 
+ /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const + { + ignore = p_arg; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp index fcb2ba6a4d7be839c11cf9794bb7beccf7845d3c..780a0c30c50fa09ea755f94bea4f6eb03fe5ad5d 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -1,35 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include -#include - -#include "device_grouped_gemm.hpp" +#include "device_grouped_gemm_splitk.hpp" namespace ck { namespace tensor_operation { namespace device { -template -struct GroupedGemmKernelArgument -{ - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - - index_t M; - index_t N; - index_t K; - - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; -}; - template -struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK { - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; - virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp deleted file mode 
100644 index d91eac07302fc31a662a9c55c5a5a6d9894bd7d0..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "device_grouped_gemm.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -/// -/// @brief Structure representing single GEMM problem arguments. -/// -/// The pointer to the vector of those structures is passed to the GroupedGEMM entry -/// point kernel. -/// -/// @tparam NumDTensor The number of D input tensors. -/// -template -struct GroupedGemmMultipleDKernelArguments -{ - __host__ __device__ - GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, - const void* p_b_grid_, - std::array p_ds_grid_, - void* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_) - : p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, - p_ds_grid{p_ds_grid_}, - p_e_grid{p_e_grid_}, - M{M_}, - N{N_}, - K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, - StrideDs{StrideDs_}, - StrideE{StrideE_} - { - } - - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; - - void Print() const - { - std::stringstream str; - for(auto sd : StrideDs) - str << sd << ","; - - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SE:" << StrideE << ", " - << "SDs: {" << str.str() << "}" - << "}" << std::endl; - } -}; - -template -struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm -{ - 
//---------------------------------------------------------------------------------------------- - /// @brief Sets the k batch size. - /// - /// @param p_arg Pointer to the Argument we're going to change. - /// @param[in] kbatch The kbatch value. - /// - virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Sets the device kernel arguments pointer. - /// - /// @param p_arg The pointer to the Argument we're going to update. - /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel - /// arguments. - /// - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Gets the device kernel argument size. - /// - /// @param[in] p_arg The pointer to the Device op Argument. - /// - /// @return The device kernel argument size. - /// - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp index 06d180d30fb0f3c8e3b348dd1668ff582de87493..3ea6501902712e48e2017fefe766c670bc0f5a36 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp @@ -1,6 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once -#include -#include #include "device_grouped_gemm.hpp" @@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm { + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const + { + this->SetKBatchSize(p_arg, kbatch); + }; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp index c1030f31ccc187192aa5e552fab79e41e785861d..712fbfd9e9ac233d2cbc7c39b8a9ab286220f4fc 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -3,83 +3,20 @@ #pragma once -#include -#include -#include -#include - #include "device_grouped_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { +/// @brief Grouped GEMM kernel using output Tile Looping algorithm /// -/// @brief Structure representing single GEMM problem arguments. -/// -/// The pointer to the vector of those structures is passed to the GroupedGEMM entry -/// point kernel. -/// -/// @tparam NumDTensor The number of D input tensors. +/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K) +/// It requires only the number of groups to launch. 
Other information like +/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic +/// (known only at kernel run-time). /// -template -struct GroupedGemmTileLoopKernelArguments -{ - __host__ __device__ - GroupedGemmTileLoopKernelArguments(const void* p_a_grid_, - const void* p_b_grid_, - std::array p_ds_grid_, - void* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_) - : p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, - p_ds_grid{p_ds_grid_}, - p_e_grid{p_e_grid_}, - M{M_}, - N{N_}, - K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, - StrideDs{StrideDs_}, - StrideE{StrideE_} - { - } - - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; - - void Print() const - { - std::stringstream str; - for(auto sd : StrideDs) - str << sd << ","; - - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SE:" << StrideE << ", " - << "SDs: {" << str.str() << "}" - << "}" << std::endl; - } -}; +/// @note This kernel does not support SplitK. template { - //---------------------------------------------------------------------------------------------- - /// @brief Sets the device kernel arguments pointer. - /// - /// @param p_arg The pointer to the Argument we're going to update. - /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel - /// arguments. - /// - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Gets the device kernel argument size. - /// - /// @param[in] p_arg The pointer to the Device op Argument. 
- /// - /// @return The device kernel argument size. - /// - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 0bb45b18c3e19b2ec5f9347c1e811d8734ee45a9..997dcb75a6faee60912f23d55f77c54a60cd2ff4 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -29,6 +29,7 @@ enum struct GemmSpecialization MNKOPadding, }; +#ifndef CK_CODE_GEN_RTC inline std::string getGemmSpecializationString(const GemmSpecialization& s) { switch(s) @@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s) default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 180e32c8b6b8ee41577b3f9700614990a102b2ee..00518b369f4193327806d062e6cd9662ebb6489a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -3,11 +3,17 @@ #pragma once +#ifndef CK_CODE_GEN_RTC #include #include #include #include #include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ 
-15,15 +21,12 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/io.hpp" namespace ck { namespace tensor_operation { @@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -259,8 +261,13 @@ __global__ void } // namespace +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. 
@@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it - using GemmADataType = std::conditional_t, ADataType>; - using GemmBDataType = std::conditional_t, BDataType>; + using GemmADataType = ck::conditional_t, ADataType>; + using GemmBDataType = ck::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ @@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + ck::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. - using APointers = - std::conditional_t&, const void*>; - using BPointers = - std::conditional_t&, const void*>; + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). 
using AGridPointer = remove_cvref_t< @@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static_for<0, NumDTensor, 1>{}([&](auto i) { using DLayout = remove_cvref_t>; - // FIXME: layout if constexpr(is_same_v || is_same_v || is_same_v || is_same_v || @@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) { - std::array a_g_n_c_wis_lengths_i32; - std::array a_g_n_c_wis_strides_i32; - std::array b_g_k_c_xs_lengths_i32; - std::array b_g_k_c_xs_strides_i32; - std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; - std::array, NumDTensor> ds_g_n_k_wos_strides_i32; - std::array e_g_n_k_wos_lengths_i32; - std::array e_g_n_k_wos_strides_i32; - std::array conv_filter_strides_i32; - std::array conv_filter_dilations_i32; - std::array input_left_pads_i32; - std::array input_right_pads_i32; + ck::Array a_g_n_c_wis_lengths_i32; + ck::Array a_g_n_c_wis_strides_i32; + ck::Array b_g_k_c_xs_lengths_i32; + ck::Array b_g_k_c_xs_strides_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_i32; + ck::Array e_g_n_k_wos_lengths_i32; + ck::Array e_g_n_k_wos_strides_i32; + ck::Array conv_filter_strides_i32; + ck::Array conv_filter_dilations_i32; + ck::Array input_left_pads_i32; + ck::Array input_right_pads_i32; array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 64aa398d531e634c54c3d55d591789f72943bd0e..d53fbca4eaf4b62ec419a37071881ee192130b48 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -56,8 +56,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d06eab1264684b595dc87dc1cf91a9f2b44a5056..25a9d7f96dea69a339b977c2cc7a9e6a973317c7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -74,8 +74,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index e950169ccfe58887351c4bfbc82af85a061287cf..985752796bfa5f5a62bf88c881840af0fa4e95ee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const index_t batch_count, 
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -108,7 +107,7 @@ __global__ void ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index d6b92bc97a8d3eed0d0ce50ad26a9169c498671a..630f143260495a9aeb11f0c764a03ef21b348756 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -83,8 +83,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index eb2eb054721268c5bd05763af68d57fa59e79598..302001642d507d2e47eb0e6a7b1d886b236b9da0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -70,8 +70,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 314ecdf76e59da5ccc9e59a67691968426fa4b1f..5f5bea4f8635bb22742789fdb2055d37a8611dd3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -41,12 +41,15 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + // populate pointer, desc for Ds static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { // D pointer @@ -54,8 +57,8 @@ __global__ void }); GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + karg.p_a_grid + a_batch_offset + 
splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid + c_batch_offset, p_shared, @@ -87,12 +90,15 @@ __global__ void __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + // populate pointer, desc for Ds static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { // D pointer @@ -100,8 +106,8 @@ __global__ void }); GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid + c_batch_offset, p_shared_0, @@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t Batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_) + CElementwiseOperation c_element_op_, + index_t KBatch_) : GridwiseGemm::Argument{p_a_grid_, p_b_grid_, p_ds_grid_, @@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 StrideB_, StrideDs_, StrideE_, - 1, + KBatch_, a_element_op_, b_element_op_, c_element_op_}, @@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); } index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); float ave_time = 0; @@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 rotating_mem.Next(); // clear c mem if(arg_.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); + hipGetErrorString( + hipMemsetAsync(arg_.p_c_grid, + 0, + arg.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + index_t KBatch = 1) { return Argument{static_cast(p_a), static_cast(p_b), @@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 Batch, a_element_op, b_element_op, - c_element_op}; + c_element_op, + KBatch}; } static auto MakeInvoker() { return Invoker{}; } @@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + CElementwiseOperation c_element_op, + index_t KBatch = 1) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 Batch, a_element_op, b_element_op, - c_element_op); + c_element_op, + KBatch); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 
34b1d503afe78e324d43f5fb7df6531809756e99..30ae72a63e98dfb8f86b7ccf45f620b4ba3633a8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -59,8 +59,7 @@ __global__ void const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index e178b8f5252781ead149f5d2b78f8fc53125a3af..2662e5c360b2d8e16082e62e84b9779ef913b00d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -67,8 +67,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -127,7 +126,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if 
(defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 9af1a447814d530a06b5de7083ad691c03e6962e..bfbcebd7c8c62721b8bb32e9934c0ca7e28df8a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -62,8 +62,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -112,7 +111,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 6be2ffbdd781d27cc43845b7da311bc8a929e42b..494524b6f0588c33687ffd4a434977a70a173a10 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -52,8 +52,7 @@ __global__ void #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t 
num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..963f0edd08813251f2ecf06e2bc4e847f06827e2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -0,0 +1,1007 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { + +// Currently we do not have an elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separated declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependency + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char 
p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale + : public DeviceBatchedGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_) / BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + __host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideScaleB_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t 
BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + StrideScaleB_, + p_b_scale_grid_, + KBatch_, // KBatch + a_element_op_, + b_element_op_, + c_element_op_), + Batch(Batch_), + compute_ptr_offset_of_batch( + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_) + { + } + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto 
run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); 
+ } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + 
true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + 
GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = 
kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // 
polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const BScaleDataType* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + p_b_scale, + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, 
+ index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const void* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + static_cast(p_b_scale), + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 6e6921351356740f899daab6b6343a46008072fc..8aa20f7ad476e0292a482371d5a2dd27ed5ae35f 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -55,8 +55,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -97,7 +96,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index 811f1ae9396de6f9166c4d6c78dcc9c757e42996..b9467ac1945cad037d30083e09d124ceac50bfbd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / 
sizeof(ABDataType); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index eaafd7d5c5eb05b69447ba1b6157da1cab25e230..47fb630ea9037432b506cb30ce5699b645e22365 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -63,8 +63,7 @@ __global__ void const Block2ETileMap block_2_etile_map, index_t NRaw) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; GridwiseGemmWelford::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index bb2db930c8ea7b1ed34b17a7cde56a2c7c6daafc..c048e7249c305fba0d2cfcfab1a0e78199e960f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 77ed9625c5d6429967ff9636ce9829a28c3b7345..e6466a487b11826929faf2454796bb457b2710ef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -52,8 +52,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp index 452063156e2ecfa983ddc99d6896adf18505d1c2..26be5cfc613a1fc9a18341709b4cc17e532b19ad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 0) { arg.Print(); @@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2(arg.p_workspace_) + + arg.block_2_ctile_map_streamk.get_workspace_size_for_acc( + sizeof(GemmAccDataType)); + auto preprocess = [&]() { + hipMemsetAsync( + workspace_semaphore, + 0, + // sizeof(uint32_t), + arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(), + stream_config.stream_id_); + }; + + ave_time = launch_and_time_kernel_with_preprocess( + stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg); + } } }; @@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - Run(kernel); - } + const 
auto kernel = kernel_gemm_xdl_cshuffle_v3; + + Run(kernel); } // Tail number could be One to Seven else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) @@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; - Run(kernel); - } + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); } } else { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); } } } @@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); } } @@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2(pArg); + if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType)); + } + else + { + return 0; + } + } + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + } + static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check @@ -430,7 +469,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 && + arg.Streamk_sel > 0) + { + return false; + } if((arg.K % 
AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || @@ -464,8 +507,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; + calculate_grid_size(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if 
constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + calculate_grid_size(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + calculate_grid_size(kernel); + } + } + else + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 4489b2e5ce56bab4e6be7549bf22f500e7861a69..1c1449665042827023f1495df637b714ecc508a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -64,7 +64,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2; + ComputeTypeB, + PermuteA, + PermuteB>; using Argument = typename GridwiseGemm::Argument; @@ -134,6 +138,7 @@ struct 
DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -633,6 +638,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(p_arg)); } + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, @@ -724,11 +734,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + 
BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); 
+ + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = 
kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); 
+ } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = 
kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == 
GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map 
BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index cc022b89c5ed472758300a1cd4e19cc5b2924f09..1cf58fec258123dee7957534061e14f0f056479a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,8 +37,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 3fb047f207b723563d50951b61a0d18c31d2faf0..359711e5c41a11c00d7a5fb11b1d1c948d5b4221 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static constexpr auto I3 = Number<3>{}; static constexpr index_t KPerBlock = K0PerBlock * K1; - static constexpr auto transform_conv_to_gemm = - TransformConvBwdDataToGemm_v1{}; - - static auto GetDummyABDsEGridDescriptor() - { - const std::array dummy_tensor_lengths = {1}; - const std::array dummy_tensor_strides = {1}; - const std::array dummy_spatial_lengths = {1}; - - const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); - - const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); - - const auto ds_grid_desc_m_n = generate_tuple( - [&](auto i) { - using DLayout = remove_cvref_t>; - - return transform_conv_to_gemm.template MakeCDescriptor_M_N( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); - }, - Number{}); - - const auto e_grid_desc_m_n = - 
transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + const auto ds_grid_desc_m_n = + generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); }, + Number{}); + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); return make_tuple( a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); } // desc - using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); using AGridDesc_AK0_M_AK1 = remove_cvref_t>; using BGridDesc_BK0_N_BK1 = remove_cvref_t>; @@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, const std::array, NumDTensor>& - ds_g_n_c_wis_lengths, + /*ds_g_n_c_wis_lengths*/, const std::array, NumDTensor>& ds_g_n_c_wis_strides, const std::array& e_g_n_c_wis_lengths, @@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, - a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, - ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, - 
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, - e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, - e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { @@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle tildes = {i_ztilde, i_ytilde, i_xtilde}; } + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); + conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); + conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); DsGridDesc_M_N ds_grid_desc_m_n; // populate Ds desc static_for<0, NumDTensor, 1>{}([&](auto i) { using DLayout = remove_cvref_t>; - - ds_grid_desc_m_n(i) = - transform_conv_to_gemm.template MakeCDescriptor_M_N( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_c_wis_lengths[i], - ds_g_n_c_wis_strides[i], - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); - }); - - const auto e_grid_desc_m_n = - 
transform_conv_to_gemm.template MakeCDescriptor_M_N( + static_assert(is_same_v); + ConvToGemmBwdDataTransform conv_to_gemm_transform_d{ a_g_n_k_wos_lengths, a_g_n_k_wos_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, + ds_g_n_c_wis_strides[i], conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads, - tildes); + tildes}; + + ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); // for check validity ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); @@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle BElementwiseOp b_element_op_; CDEElementwiseOp cde_element_op_; - // for checking IsSupportedArgument() std::array a_g_n_k_wos_lengths_; - std::array a_g_n_k_wos_strides_; std::array b_g_k_c_xs_lengths_; - std::array b_g_k_c_xs_strides_; - std::array, NumDTensor> ds_g_n_c_wis_lengths_; - std::array, NumDTensor> ds_g_n_c_wis_strides_; - std::array e_g_n_c_wis_lengths_; - std::array e_g_n_c_wis_strides_; std::array conv_filter_strides_; - std::array conv_filter_dilations_; std::array input_left_pads_; std::array input_right_pads_; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index b544c925e1394ac05012ad37e09d4611e6da2606..99bd3be15df3688c2b7688d2acde6c74bbf03fff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -54,15 +54,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -73,10 +74,9 @@ __global__ void const ABDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, EDataType* __restrict__ p_e_grid, - const 
AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -84,24 +84,28 @@ __global__ void const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); - const long_index_t a_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = amd_wave_read_first_lane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + 
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; DsPointer p_ds_grid_grp; @@ -112,10 +116,10 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, + GridwiseGemm::template Run(p_a_grid + a_batch_offset + a_n_offset, p_b_grid + b_batch_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_batch_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -130,7 +134,6 @@ __global__ void ignore = p_b_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = batch_count; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -139,6 +142,7 @@ __global__ void ignore = b_element_op; ignore = cde_element_op; ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; ignore = block_2_ctile_map; #endif } @@ -233,82 +237,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto transform_conv_to_gemm = - TransformConvBwdDataToGemm_v1{}; - - static auto GetDummyABDsEGridDescriptor() + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) { - const std::array dummy_tensor_lengths = {1}; - const std::array dummy_tensor_strides = {1}; - const std::array dummy_spatial_lengths = {1}; - - const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template 
MakeADescriptor_AK0_M_AK1( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); - - const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); const auto ds_grid_desc_m_n = generate_tuple( [&](auto i) { - using DLayout = remove_cvref_t>; - - return transform_conv_to_gemm.template MakeCDescriptor_M_N( - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); }, Number{}); - const auto e_grid_desc_m_n = - transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_tensor_lengths, - dummy_tensor_strides, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths, - dummy_spatial_lengths); + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); return make_tuple( a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, 
ds_grid_desc_m_n, e_grid_desc_m_n); @@ -377,7 +353,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // desc - using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); using AGridDesc_AK0_M_AK1 = remove_cvref_t>; using BGridDesc_BK0_N_BK1 = remove_cvref_t>; @@ -431,15 +408,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, - a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, - ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, - ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, - e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, - e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { @@ -450,11 +420,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 p_ds_grid_(i) = static_cast(p_ds[i]); }); - // A/B/Ds/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; - static_for<0, NumDTensor, 1>{}([&](auto i) { compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; }); @@ -526,68 +491,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); } + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); + conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); + conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); DsGridDesc_M_N ds_grid_desc_m_n; // populate Ds desc static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - ds_grid_desc_m_n(i) = - transform_conv_to_gemm.template MakeCDescriptor_M_N( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_c_wis_lengths[i], - ds_g_n_c_wis_strides[i], - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - tildes); - }); - - const auto e_grid_desc_m_n = - transform_conv_to_gemm.template MakeCDescriptor_M_N( + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ a_g_n_k_wos_lengths, a_g_n_k_wos_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, - 
e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads, - tildes); + tildes}; + + ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); // desc for problem definition const auto a_grid_desc_m_k = @@ -628,6 +590,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } } + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_; } void Print() const @@ -660,6 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // tensor descriptor for problem definition index_t num_group_; + index_t conv_N_per_block_; std::vector a_grid_desc_m_k_container_; std::vector b_grid_desc_n_k_container_; std::vector ds_grid_desc_m_n_container_; @@ -678,23 +648,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // for computing batch offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOp a_element_op_; BElementwiseOp b_element_op_; CDEElementwiseOp cde_element_op_; - // for checking IsSupportedArgument() std::array a_g_n_k_wos_lengths_; - std::array a_g_n_k_wos_strides_; std::array b_g_k_c_xs_lengths_; - std::array b_g_k_c_xs_strides_; - std::array, NumDTensor> ds_g_n_c_wis_lengths_; - std::array, NumDTensor> ds_g_n_c_wis_strides_; - std::array e_g_n_c_wis_lengths_; - std::array e_g_n_c_wis_strides_; std::array conv_filter_strides_; - std::array conv_filter_dilations_; std::array 
input_left_pads_; std::array input_right_pads_; }; @@ -711,8 +674,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.Print(); } - float ave_time = 0; + const index_t gdy = arg.num_group_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_; + const index_t gdz = num_workgroups_per_Conv_N; + float ave_time = 0; for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], @@ -724,9 +691,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 throw std::runtime_error("wrong! device_op has invalid setting"); } - const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize( - arg.e_grid_desc_m_n_container_[i]) * - arg.num_group_; + const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]); const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); @@ -747,12 +713,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -762,13 +729,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_k_wos_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_container_[i], arg.b_grid_desc_bk0_n_bk1_container_[i], arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.block_2_etile_map_container_[i], - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_); }; if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a7df1c9d570c81af44f226709d1334f7f254810c..57c4b1a5cf46c411084fcbe74b5b84c8f89f4ace 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -103,7 +102,7 @@ __global__ void compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); compute_ptr_offset_of_batch.GetCPtrOffset(0); -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template 1)) { return false; } - if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 && + NumGroupsToMerge > 1)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 996107343d510f6b0c5e418881fa90ea5b60e4b2..ef87bb52ae37819484fe22759f59ae30425bac27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { return false; } + if(!is_bf16_atomic_supported() && std::is_same_v) + { + return false; + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWC_GKXC_GNWK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f21a45938f5079a4093c6dfddf41e3a7390d5e45..02ca8f42e496ccc69f1dd5a5e34687c47e507d7a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -9,6 +9,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -98,8 +99,7 @@ __global__ void const ComputePtrOffsetOfG compute_ptr_offset_of_groups, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -121,19 +121,6 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); - if constexpr(is_same_v) - { - a_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - b_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - cde_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; @@ -225,9 +212,13 @@ __global__ void } } // namespace - +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. 
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 589a0daa99d4070ebf72d64961e239cfdc7c5488..9363d7ecb9a0f411412780db91a46c19e7bbbbfa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -9,6 +9,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -117,7 +118,7 @@ __global__ void c_grid_desc_mblock_mperblock_nblock_nperblock); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template = false> struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage - : public DeviceGroupedGemmMultipleDSplitK + : public DeviceGroupedGemmSplitK { using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; @@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage index_t skipped_group_count_; index_t grid_size_; // Pointer to device memory with GEMM kernel arguments. - const void* p_dev_gemm_args_; + void* p_dev_gemm_kargs_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// @return The average kernel execution time (if time measurement is enabled.) 
/// float Run(const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_args, void* dev_gemm_workspace, const StreamConfig& stream_config = StreamConfig{}) { @@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(arg.p_dev_gemm_args_ == nullptr) + if(arg.p_dev_gemm_kargs_ == nullptr) { std::ostringstream err; err << "The gemm arguments device buffer is not allocated!" @@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage throw std::runtime_error(err.str()); } - return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); } float Run(const BaseArgument* p_arg, @@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage template float DispatchKernel(const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_kargs, void* dev_gemm_workspace, const StreamConfig& stream_config) const { @@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return LaunchKernel(gemm_kernel, elementwise_kernel, arg, - dev_gemm_args, + dev_gemm_kargs, dev_gemm_workspace, stream_config); } @@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage float LaunchKernel(const KernelFunction& gemm_kernel, const KernelFunction2& elementwise_kernel, const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_kargs, [[maybe_unused]] void* dev_gemm_workspace, const StreamConfig& stream_config) const { float time{0.f}; + hip_check_error( + hipMemcpyAsync(dev_gemm_kargs, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + auto preprocess = [&]() { hip_check_error(hipMemsetAsync( dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); @@ -774,7 +780,7 @@ struct 
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage dim3(arg.grid_size_), dim3(BlockSize), 0, - cast_pointer_to_constant_address_space(dev_gemm_args), + cast_pointer_to_constant_address_space(dev_gemm_kargs), arg.gemm_kernel_args_.size(), arg.a_element_op_, arg.b_element_op_, @@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return str.str(); } - void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override { - arg.p_dev_gemm_args_ = p_dev_kernel_args; - hip_check_error(hipMemcpy(p_dev_kernel_args, - arg.gemm_kernel_args_.data(), - GetDeviceKernelArgSize(&arg), - hipMemcpyHostToDevice)); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } - void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override { - return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override @@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } - static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } - - void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + [[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch) { - return 
SetKBatchSize(*dynamic_cast(p_arg), kbatch); + arg.UpdateKBatch(kbatch); } - size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override { - return dynamic_cast(p_arg)->gemm_kernel_args_.size() * - sizeof(GemmTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 2884e558cd359494dad4119133348af61765e08a..61058dec2b2565e647cfc87d65cde25e32bf952d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -20,7 +20,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" namespace ck { @@ -69,8 +68,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -405,7 +403,7 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; -#endif // end 
of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ; - using KernelArguments = GroupedGemmTileLoopKernelArguments; + using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; @@ -936,12 +934,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return str.str(); } + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpyAsync(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const { arg.p_dev_gemm_args_ = p_dev_kernel_args; } - void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override { return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 871fbd019e7f68348399677e4c6e5a0ae7c22a52..3fb2c5ae86d60d42eee125995e6004c4ce271cdf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -43,8 +43,7 @@ __global__ void const B1ElementwiseOperation b1_element_op, const 
CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -109,7 +108,7 @@ __global__ void ignore = acc_element_op; ignore = b1_element_op; ignore = c_element_op; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 658f3235168f247ef888552a2a07a72f5e6fc0f5..8b40eea56c988ff57c293f51e2418b9097c00df4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -38,8 +38,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -557,12 +556,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index ac05a0703fbeea29345e19d0fc55ce8172a150e6..8fe71fb9a2d61fa29d669412296aa20499b7183d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -50,8 +50,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -445,6 +444,7 @@ struct 
DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + // TODO: replace with GroupedGemmKernelArgument struct GemmBiasTransKernelArg { // pointers @@ -900,40 +900,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(p_arg), kernel_args); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace, const StreamConfig& stream_config = StreamConfig{}) const override { - auto p_arg_ = dynamic_cast(p_arg); - p_arg_->p_workspace_ = p_workspace; + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); hip_check_error( - hipMemsetAsync(p_workspace, 0, 
GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } @@ -941,7 +959,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(p_arg), k_batch); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 6d9d1459c8f54e6d506268cf21650dd1010d7d5d..994c667fbcb35eb10897603affcf51a4ac18d3d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -40,8 +40,7 @@ __global__ void const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -80,7 +79,7 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template && 
arg.K_BATCH > 1 && !is_bf16_atomic_supported()) + { + return false; + } + bool supported = true; for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { - const auto& a = arg.gemm_kernel_args_[i].karg_; + const auto& a = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); if(not group_arg_valid) { @@ -631,16 +636,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(p_arg)->gemm_kernel_args_.size() * - sizeof(GemmTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); } + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + // TODO: deperecation notice. static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } // polymorphic void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override { - return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 648736fcbfafe0cf1ee987521b91af8ac3c876c2..1ad37058db10c8e9146be72f5f8bba6dcc7d6103 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/library/utility/numeric.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 609697cb97a465c8248823b83930f1daa416a7e3..63b49d9aa0c7c519e22b32331a1993d98bc45e3d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -56,8 +56,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index 2202bc5695989041f07ba1bda02bb025104766ba..85adb64b430ed37b51976fafc7738034f6076098 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout } // namespace convolution +#ifndef CK_CODE_GEN_RTC template < typename Layout, typename std::enable_if::value, bool>::type = false> @@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&) os << Layout::name; return os; } +#endif } // namespace tensor_layout } // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index c87c90a91dd9f731713553e0d56bd24a9a0b25d2..530876650ee5a676cb7072a8d0f110708480f048 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -340,8 +340,8 @@ struct Bilinear }; template <> - __host__ __device__ constexpr void operator()( - std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x0, const int8_t& x1) const { y = type_convert(alpha_ * type_convert(x0) + beta_ * type_convert(x1)); diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index b914c0b96f7041d44c8ff0761e868c6e7b581658..370d03258da7006e069b5fc7d41a8318eb73c41a 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -533,7 +533,7 @@ struct NormalizeInInfer const T3& gamma, const T4& beta) const { - static_assert(std::is_same::value || std::is_same::value, + static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); using ck::type_convert; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 39b81ca5730124f0799446003b74a98771bb2a93..f1055d1eff8c54dc2a9c6ad082c7ce512c3fefd1 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,36 +7,203 @@ #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/amd_inline_asm.hpp" #include namespace ck { + +// Fast int4x4 to half8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +// Convert lower part of packed int4 -> int4 to half +__device__ inline half4_t i4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. 
+ int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. + int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<0>{})) + : "v"(res.template AsType()(Number<0>{})), "v"(scale)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<1>{})) + : "v"(res.template AsType()(Number<1>{})), "v"(scale)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline bhalf4_t i4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float 
fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + vector_type res; + res.template AsType()(Number<0>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.template AsType()(Number<1>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res.template AsType()[Number<0>{}]; +} + namespace tensor_operation { namespace element_wise { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnon-virtual-dtor" -struct UnaryOpBase +struct PassThroughPack8 { - public: - __host__ __device__ ~UnaryOpBase() = default; + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_half4(bit_cast(x) >> 8); - __host__ __device__ constexpr UnaryOpBase() = default; - __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default; - __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default; - __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; - __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; - __host__ __device__ virtual inline void operator()(float& y, 
const float& x) const = 0; + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); - __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0; + y = dst.template AsType()[Number<0>{}]; +#endif + } + + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; - __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0; + result.template AsType()(Number<0>{}) = i4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_bhalf4(bit_cast(x) >> 16); - __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0; + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; - __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0; + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); - __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0; + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; +}; + +struct DequantPack8 +{ + template + __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const; + + __host__ __device__ constexpr void + operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const 
ck::half2_t& z) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4_scale(bit_cast(x), z); + result.template AsType()(Number<1>{}) = + i4_to_half4_scale(bit_cast(x) >> 8, z); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + constexpr const static bool is_pack8_invocable = true; }; struct PassThroughPack2 @@ -44,38 +211,49 @@ struct PassThroughPack2 template __host__ __device__ void operator()(Y& y, const X& x) const; - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const + __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const { auto t = type_convert(x); y = type_convert(t); } - constexpr const static bool is_pack2_invocable = true; -}; -struct PassThrough final : public UnaryOpBase -{ - __host__ __device__ constexpr PassThrough() = default; - __host__ __device__ constexpr PassThrough(const PassThrough&) = default; - __host__ __device__ constexpr PassThrough(PassThrough&&) = default; - __host__ __device__ PassThrough& operator=(const PassThrough&) = default; - __host__ __device__ PassThrough& operator=(PassThrough&&) = default; - __host__ __device__ ~PassThrough() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } - - __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; } + __host__ 
__device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + uint8_t x_u8 = ck::bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; } + auto l_f16 = ck::type_convert(x_l); + auto h_f16 = ck::type_convert(x_h); - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; } + y = {l_f16, h_f16}; +#else + uint32_t t = ck::bit_cast(x); + y = ck::bit_cast(t); +#endif + } - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; } + constexpr const static bool is_pack2_invocable = true; +}; +struct PassThrough +{ template __host__ __device__ void operator()(Y& y, const X& x) const; + template <> + __host__ __device__ void operator()(pk_i4_t& y, const pk_i4_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(float& y, const double& x) const { @@ -88,12 +266,36 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const float& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(bhalf_t& y, const float& x) const { @@ -118,6 +320,12 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const 
int8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const int8_t& x) const { @@ -230,7 +438,7 @@ struct PassThrough final : public UnaryOpBase template <> __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - y = ck::type_convert(x); + y = type_convert(x); } }; @@ -303,21 +511,21 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const { - y = ck::type_convert(ck::type_convert(x) * scale_); + y = type_convert(type_convert(x) * scale_); } template <> __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::type_convert(scale_) * x; + y = type_convert(scale_) * x; }; template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - const float x_tmp = ck::type_convert(x); + const float x_tmp = type_convert(x); const float y_tmp = scale_ * x_tmp; - y = ck::type_convert(y_tmp); + y = type_convert(y_tmp); }; template <> @@ -335,7 +543,7 @@ struct Scale template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { - y = ck::type_convert(scale_ * ck::type_convert(x)); + y = type_convert(scale_ * type_convert(x)); }; float scale_; @@ -351,7 +559,7 @@ struct ScaleAndResetNaNToMinusInfinity template <> __host__ __device__ void operator()(float& y, const float& x) const { - y = ck::math::isnan(x) ? -ck::NumericLimits::Infinity() : scale_ * x; + y = math::isnan(x) ? 
-NumericLimits::Infinity() : scale_ * x; }; float scale_; @@ -417,45 +625,21 @@ struct UnarySquare }; }; -struct UnaryAbs final : public UnaryOpBase +struct UnaryAbs { - __host__ __device__ constexpr UnaryAbs() = default; - __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; - __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default; - __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default; - __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default; - __host__ __device__ ~UnaryAbs() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::abs(x); - } - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - y = ck::math::abs(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::abs(x); - } + y = math::abs(x); + }; + template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { y = ck::type_convert(ck::math::abs(ck::type_convert(x))); @@ -470,49 +654,28 @@ struct UnarySqrt static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sqrt(x); + y = math::sqrt(x); }; }; -struct Relu final : public UnaryOpBase +struct Relu { - __host__ __device__ constexpr Relu() = default; - __host__ __device__ constexpr Relu(const Relu&) = default; - 
__host__ __device__ constexpr Relu(Relu&&) = default; - __host__ __device__ Relu& operator=(const Relu&) = default; - __host__ __device__ Relu& operator=(Relu&&) = default; - __host__ __device__ ~Relu() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); y = x > 0 ? x : 0; } - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - float x_f32 = ck::type_convert(x); + float x_f32 = type_convert(x); float y_f32 = x_f32 > 0 ? 
x_f32 : 0; - y = ck::type_convert(y_f32); + y = type_convert(y_f32); } }; @@ -528,7 +691,7 @@ struct FastGelu template __device__ void operator()(Y& y, const X& x) const; - +#ifndef CK_CODE_GEN_RTC template <> __host__ void operator()(float& y, const float& x) const { @@ -539,6 +702,7 @@ struct FastGelu const float emu = exp(u); y = x / (1.f + emu); } +#endif // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> @@ -550,7 +714,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __ocml_exp_f32(u); - y = x * ck::math::rcp(1.f + emu); + y = x * math::rcp(1.f + emu); } template <> @@ -648,59 +812,24 @@ struct Gelu } template <> - __host__ __device__ void operator()(ck::half_t& y, - const ck::half_t& x) const + __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); + y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x)))); } }; -struct Sigmoid final : public UnaryOpBase +struct Sigmoid { - __host__ __device__ constexpr Sigmoid() = default; - __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; - __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default; - __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default; - __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default; - __host__ __device__ ~Sigmoid() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - constexpr float one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - constexpr double one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - constexpr int32_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void 
operator()(int8_t& y, const int8_t& x) const final - { - constexpr int8_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - constexpr half_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - constexpr float one = type_convert(1); - float x_f32 = ck::type_convert(x); - float y_f32 = one / (one + ck::math::exp(x_f32)); - y = ck::type_convert(y_f32); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + math::exp(-x)); + }; }; struct Silu @@ -708,52 +837,26 @@ struct Silu template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same_v || is_same_v || is_same_v || + static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Data type is not supported by this operation!"); constexpr T one = type_convert(1); - y = x * (one / (one + ck::math::exp(-x))); + y = x * (one / (one + math::exp(-x))); }; }; -struct TanH final : public UnaryOpBase +struct TanH { - __host__ __device__ constexpr TanH() = default; - __host__ __device__ constexpr TanH(const TanH&) = default; - __host__ __device__ constexpr TanH(TanH&&) = default; - __host__ __device__ TanH& operator=(const TanH&) = default; - __host__ __device__ TanH& operator=(TanH&&) = default; - __host__ __device__ ~TanH() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const 
int32_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::tanh(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::tanh(x); - } + y = math::tanh(x); + }; }; struct ACos @@ -762,11 +865,11 @@ struct ACos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acos(x); + y = math::acos(x); }; }; @@ -776,11 +879,11 @@ struct Neg __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::neg(x); + y = math::neg(x); }; }; @@ -790,11 +893,11 @@ struct ATan __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atan(x); + y = math::atan(x); }; }; @@ -804,11 +907,11 @@ struct Sin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = 
ck::math::sin(x); + y = math::sin(x); }; }; @@ -818,11 +921,11 @@ struct ASinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asinh(x); + y = math::asinh(x); }; }; @@ -832,11 +935,11 @@ struct Cos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cos(x); + y = cos(x); }; }; @@ -846,11 +949,11 @@ struct ACosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acosh(x); + y = math::acosh(x); }; }; @@ -860,11 +963,11 @@ struct Tan __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::tan(x); + y = math::tan(x); }; }; @@ -874,11 +977,11 @@ struct ATanH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atanh(x); + y = math::atanh(x); }; }; @@ -888,11 +991,11 @@ struct SinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not 
supported by this operation!"); - y = ck::math::sinh(x); + y = math::sinh(x); }; }; @@ -902,11 +1005,11 @@ struct Ceil __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::ceil(x); + y = math::ceil(x); }; }; @@ -916,11 +1019,11 @@ struct Exp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::exp(x); + y = math::exp(x); }; }; @@ -930,11 +1033,11 @@ struct CosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cosh(x); + y = math::cosh(x); }; }; @@ -944,11 +1047,11 @@ struct Floor __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::floor(x); + y = math::floor(x); }; }; @@ -958,11 +1061,11 @@ struct Log __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::log(x); + y = math::log(x); }; }; @@ -972,11 +1075,11 @@ struct ASin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || 
is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asin(x); + y = math::asin(x); }; }; @@ -986,426 +1089,146 @@ struct Rcp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::rcp(x); + y = math::rcp(x); }; }; -struct Swish final : public UnaryOpBase +struct Swish { - __host__ __device__ constexpr Swish(const Swish&) = default; - __host__ __device__ constexpr Swish(Swish&&) = default; - __host__ __device__ ~Swish() = default; - - __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} - - __host__ __device__ float get_beta() const { return beta_; } - - const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } + Swish(float beta = 1.0f) : 
beta_(beta) {} template __host__ __device__ void operator()(Y& y, const X& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } + y = type_convert(x / (1.f + math::exp(bx))); + }; + + const float beta_; }; -struct SoftRelu final : public UnaryOpBase +struct SoftRelu { - __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; - __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default; - __host__ __device__ ~SoftRelu() = default; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; - __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } - - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = ck::math::log(one + 
ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha; } + const float alpha_; }; -struct Power final : public UnaryOpBase +struct Power { - __host__ __device__ constexpr Power(const Power&) = default; - __host__ __device__ constexpr Power(Power&&) = default; - __host__ __device__ ~Power() = default; + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; - __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = math::pow(shifted_scaled_x, casted_gamma); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float 
get_beta() const { return beta_; } - - __host__ __device__ float get_gamma() const { return gamma_; } - const float alpha_; const float beta_; const float gamma_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - float casted_gamma = type_convert(gamma_); - - float shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - double casted_gamma = type_convert(gamma_); - - double shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - int32_t casted_gamma = type_convert(gamma_); - - int32_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - int8_t casted_gamma = type_convert(gamma_); - - int8_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - half_t casted_gamma = type_convert(gamma_); - - half_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t 
casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - bhalf_t casted_gamma = type_convert(gamma_); - - bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } }; -struct ClippedRelu final : public UnaryOpBase +struct ClippedRelu { - __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; - __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default; - __host__ __device__ ~ClippedRelu() = default; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; - __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) - : alpha_(alpha), beta_(beta) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = math::min(casted_beta, math::max(casted_alpha, x)); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float get_beta() const { return beta_; } - const float alpha_; const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } 
- - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } }; -struct LeakyRelu final : public UnaryOpBase +struct LeakyRelu { - __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; - __host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default; - __host__ __device__ ~LeakyRelu() = default; - - __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } - - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? 
x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y, - [[maybe_unused]] const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; } + const float alpha_; }; -struct Elu final : public UnaryOpBase +struct Elu { - __host__ __device__ constexpr Elu(const Elu&) = default; - __host__ __device__ constexpr Elu(Elu&&) = default; - __host__ __device__ ~Elu() = default; - - __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + Elu(float alpha = 1.f) : alpha_(alpha){}; - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x > 0 ? 
x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * math::expm1(x); } + const float alpha_; }; -struct Logistic final : public UnaryOpBase +struct Logistic { - __host__ __device__ constexpr Logistic(const Logistic&) = default; - __host__ __device__ constexpr Logistic(Logistic&&) = default; - __host__ __device__ ~Logistic() = default; - - __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + Logistic(float alpha = 1.f) : alpha_(alpha){}; - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline 
void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); } + const float alpha_; }; struct ConvInvscale @@ -1470,7 +1293,7 @@ struct ConvScaleRelu __host__ __device__ void operator()(f8_t& e, const float& c) const { float x; - Relu{}(x, c * scale_in_ * scale_wei_); + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); e = type_convert(x * scale_out_); }; @@ -1487,10 +1310,10 @@ struct FastNumericArrayConverter }; template <> -struct FastNumericArrayConverter +struct FastNumericArrayConverter { using InputArray = vector_type; - using OutputArray = vector_type; + using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1520,13 +1343,13 @@ struct FastNumericArrayConverter }; template -struct FastNumericArrayConverter +struct FastNumericArrayConverter { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); using InputArray = vector_type; - using OutputArray = vector_type; + 
using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1535,7 +1358,7 @@ struct FastNumericArrayConverter OutputArray Output; using Vec_InputArray = vector_type; - using Vec_OutputArray = vector_type; + using Vec_OutputArray = vector_type; Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); @@ -1551,225 +1374,138 @@ struct FastNumericArrayConverter struct DynamicUnaryOp { - - DynamicUnaryOp& operator=(const DynamicUnaryOp& other) - { - if(this != &other) - { - unary_op_ptr_ = other.unary_op_ptr_; - unary_op_type_ = other.unary_op_type_; - } - return *this; - } - __host__ __device__ DynamicUnaryOp() = delete; __host__ __device__ DynamicUnaryOp(const Swish& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } __host__ __device__ DynamicUnaryOp(const Swish&& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } - __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {} - __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {} __host__ __device__ DynamicUnaryOp(const PassThrough&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const PassThrough&&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const Logistic& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } 
__host__ __device__ DynamicUnaryOp(const Logistic&& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {} - __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {} __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {} - __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {} __host__ __device__ DynamicUnaryOp(const Power& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } 
__host__ __device__ DynamicUnaryOp(const Power&& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu&& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) - : unary_op_type_(dynamic_op.unary_op_type_), - unary_op_ptr_(dynamic_op.unary_op_ptr_), - alpha(dynamic_op.alpha), - beta(dynamic_op.beta), - gamma(dynamic_op.gamma) - { - } - - __host__ __device__ ~DynamicUnaryOp() - { - switch(unary_op_type_) - { - case(UnaryOpType::Swish): delete static_cast(unary_op_ptr_); 
break; - case(UnaryOpType::Sigmoid): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::PassThrough): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Logistic): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::TanH): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Relu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::SoftRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::UnaryAbs): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Power): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::ClippedRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::LeakyRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Elu): delete static_cast(unary_op_ptr_); break; - - default: break; - } - } - - __device__ void InitUnaryOpPtrOnDevice() - { - switch(unary_op_type_) - { - case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break; - case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break; - case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break; - case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break; - case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break; - case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break; - case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break; - case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break; - case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break; - case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break; - case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break; - case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break; - - default: unary_op_ptr_ = nullptr; break; - } - } + __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default; - template - __device__ void operator()(Y& y, const X& x) const - { - isSupported(); - 
unary_op_ptr_->operator()(y, x); - } + __host__ __device__ ~DynamicUnaryOp() {} template - __host__ void operator()(Y& y, const X& x) const + __host__ __device__ void operator()(Y& y, const X& x) const { - isSupported(); switch(unary_op_type_) { - case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; - case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; - case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; - case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; - case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; - case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; - case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; - case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; - case(UnaryOpType::Power): Power{}.operator()(y, x); break; - case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; - case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; - case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; + case(UnaryOpType::Swish): swish_(y, x); break; + case(UnaryOpType::Sigmoid): sigmoid_(y, x); break; + case(UnaryOpType::PassThrough): pass_through_(y, x); break; + case(UnaryOpType::Logistic): logistic_(y, x); break; + case(UnaryOpType::TanH): tanh_(y, x); break; + case(UnaryOpType::Relu): relu_(y, x); break; + case(UnaryOpType::SoftRelu): soft_relu_(y, x); break; + case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break; + case(UnaryOpType::Power): power_(y, x); break; + case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break; + case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break; + case(UnaryOpType::Elu): elu_(y, x); break; default: break; } } - template - __device__ __host__ constexpr void isSupported() const + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - - static_assert(std::is_same::value, "X and Y must be of the same type"); - - static_assert(is_same::value || is_same::value || - is_same::value || 
is_same::value || - is_same::value || is_same::value, - "Data type is not supported by this operation!"); + float y_float; + float x_float = type_convert(x); + this->operator()(y_float, x_float); + y = type_convert(y_float); } private: @@ -1791,12 +1527,20 @@ struct DynamicUnaryOp public: UnaryOpType unary_op_type_; - UnaryOpBase* unary_op_ptr_ = nullptr; - float alpha; - float beta; - float gamma; + + Swish swish_; + Sigmoid sigmoid_; + PassThrough pass_through_; + Logistic logistic_; + TanH tanh_; + Relu relu_; + SoftRelu soft_relu_; + UnaryAbs unary_abs_; + Power power_; + ClippedRelu clipped_relu_; + LeakyRelu leaky_relu_; + Elu elu_; }; -#pragma clang diagnostic pop } // namespace element_wise } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 56c37b1b7240de0e85fe2ddd6faa7264deaaa32e..2bc9ef87acfcab6ccd54e4100e69dceb2b2a50d8 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,14 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/utility/math.hpp" #include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#ifndef CK_CODE_GEN_RTC #include #include +#endif namespace ck { @@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit // Create 3D grid const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return std::make_tuple(N0, M0, k_split); + return make_tuple(N0, M0, k_split); } template @@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t best_sk_score = - std::numeric_limits::max(); // we need to find the smallest sk iters + NumericLimits::Max(); // we need to find the smallest sk iters for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; tentative_sk_blocks++) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 9469fa7bc7b27351f393caed43a9bc12a2b8780d..73bac20e433a1d8d843d2b2b92bd39255ec8e9c7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
// therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index e4316774c9752d95e2385c154002a670b40cff9f..1d721b6e4d72e2c41b71a62987aa69c52ca9f013 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr index_t Gemm1KPack = math::max( - math::lcm( - MfmaSelector::selected_mfma.group_size, - B1K1), - MfmaSelector::selected_mfma.k_per_blk); + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
+ // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index bc76d4cc4fb9e905cdb1547273866b66575ff334..44a488c5ddf20d1b176611df4619606f0a153c31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index afb2ad2e760396c200930254f871ff13e032a91a..7d2dfab15f123cbe3cc8a03a81467e92161efa13 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
// therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 60c02d64e11cc8be94de6bd6aa8041f2d30b55a4..344656b13f6b5b77ab9ae12e8d02c763ff1d0ca8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = @@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeAsGridDescriptor_M_K(const std::array& MRaws, - const std::array& KRaws, - const std::array& AsStride) + __host__ __device__ static auto MakeAsGridDescriptor_M_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& KRaws, + const ck::Array& AsStride +#else + const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array& NRaws, - const std::array& KRaws, - const std::array& BsStride) + __host__ __device__ static auto MakeBsGridDescriptor_N_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& NRaws, + const ck::Array& KRaws, + 
const ck::Array& BsStride +#else + const std::array& NRaws, + const std::array& KRaws, + const std::array& BsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeDsGridDescriptor_M_N(const std::array& MRaws, - const std::array& NRaws, - const std::array& DsStride) + __host__ __device__ static auto MakeDsGridDescriptor_M_N( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride +#else + const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle const index_t M, const index_t N, const index_t K, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideAs, + const ck::Array StrideBs, + const ck::Array StrideDs, +#else const std::array StrideAs, const std::array StrideBs, const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index e6085fad8c8b13c541beb373114a58c22f6c1a60..eb1eb533d7aa38553ce945bc3452249f9e6d0a50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = @@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } +#ifdef CK_CODE_GEN_RTC + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride) +#else template __host__ __device__ static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, const std::array& NRaws, const std::array& DsStride) +#endif { return generate_tuple( [&](auto i) { @@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t K, const index_t StrideA, const index_t StrideB, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideDs, +#else const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index cd36b9e51ae6db6c74f491b253ae7099a259d747..b4c5d004c49808273fa0b2f970eb94b18222165f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 44cbbcd04967064ce36efb7e826d0f2714d69295..9dad66913aec8445c1d99d6d5254517bbd040cdd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef CK_CODE_GEN_RTC #include #include +#endif #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" @@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector() } else { +#ifndef CK_CODE_GEN_RTC std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; +#endif } } } // namespace ck +#ifndef CK_CODE_GEN_RTC inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) { switch(p) @@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) } return os; } +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp old mode 100644 new mode 100755 index ff10215353cf1b47d598e7dc28bd849ee139e90b..6ef35da485bcb7d28fb685da37e1aa45b13066e3 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -14,6 +14,8 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include 
"ck/utility/workgroup_barrier.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { @@ -38,7 +40,7 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -62,7 +64,13 @@ __global__ void __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run_2Lds( - karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.p_workspace_); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, - p_c_grid{p_c_grid_} + p_c_grid{p_c_grid_}, + block_2_ctile_map_streamk( + M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_) { } @@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 const ADataType* p_a_grid; const BDataType* p_b_grid; CDataType* p_c_grid; + BlockToCTileMap_GemmStreamK_v2 + block_2_ctile_map_streamk; }; struct SplitKBatchOffset @@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; } + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + using BlockwiseGemmPipe = remove_cvref_t(); + constexpr auto NPerBlockReduction = + NPerBlockPow2 / 
CShuffleBlockTransferScalarPerVector_NPerBlock; + constexpr auto MPerBlockReduction = + (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction; + return Sequence{}; + } + + __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor() + { + const auto c_partial_acc_block_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(I1, MPerBlock)); + } + }(); + return c_partial_acc_block_m_n; + } using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size, problem.Streamk_sel); uint32_t iter_start, iter_end; - bool is_sk_block, is_dp_block; + bool is_sk_block, is_dp_block, is_reduction_block; index_t num_k_block_main_loop; - + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + uint32_t* p_semaphore = reinterpret_cast( + reinterpret_cast(p_workspace) + + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims(); block_idx += gridDim.x) @@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); 
num_k_block_main_loop = iter_end - iter_start; + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + is_reduction_block = static_cast(block_idx) >= + block_2_ctile_map_streamk.reduction_start_block_idx; + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = math::integer_divide_ceil( + Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = + make_naive_tensor_descriptor_packed(make_tuple( + I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = + make_multi_index(0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_n_reverse = make_multi_index( + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = + make_multi_index(0, + 0, + 0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_n_reverse = make_multi_index( + 0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + 
CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = + block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; + auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial( + reduction_idx, problem.M, problem.N); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + + 1); + __syncthreads(); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + AccDataType, // SrcData, + AccDataType, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, // SrcData, + CDataType, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, + 1, + 1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // 
DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock), + CElementwiseOperation{}}; + + wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); + + if(threadIdx.x == 0) + { + p_semaphore[reduction_idx] = 0; + } + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + 
acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + + continue; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * + MPerBlock * NPerBlock; while(true) { uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( @@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 iter_end - 1, tile_idx, iter_offset); iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, - problem.MPadded, - problem.K, - problem.KPadded, - problem.StrideA, - problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, - problem.KPadded, - problem.N, - problem.NPadded, - problem.StrideB, - problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto block_work_idx = block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); @@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock .GetElementSpaceSize()); + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op}; - + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CShuffleDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, 
=> need to be + // false, othre wise has scratch + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, @@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } else if(is_sk_block) { - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global - .template Run( + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } } if constexpr(access_id < num_access - 1) @@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); } }); - } + + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier 
wg_barrier(p_semaphore); + wg_barrier.inc(tile_idx); + } + } + } // shuffle c and write-out end + // exit condition iter_end -= current_iter_length; if(iter_end <= iter_start) break; + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } // make sure next loop LDS is ready for use block_sync_lds(); - } - } + } // while loop + + } // for loop } template ( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + uint32_t iter_start, iter_end; - bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; + bool is_sk_block, is_dp_block, is_reduction_block; index_t num_k_block_main_loop; + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel); for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims(); block_idx += gridDim.x) @@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); num_k_block_main_loop = iter_end - iter_start; + uint32_t* p_semaphore = reinterpret_cast( + reinterpret_cast(p_workspace) + + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); + + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + is_reduction_block = 
static_cast(block_idx) >= + block_2_ctile_map_streamk.reduction_start_block_idx; + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = math::integer_divide_ceil( + Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = + make_naive_tensor_descriptor_packed(make_tuple( + I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = + make_multi_index(0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_n_reverse = make_multi_index( + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = + make_multi_index(0, + 0, + 0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_n_reverse = make_multi_index( + 0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer 
+ acc_buf; + + // start to compute + auto reduction_idx = + block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; + auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial( + reduction_idx, problem.M, problem.N); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + + 1); + + uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start; + + if(threadIdx.x == 0) + { + p_semaphore[reduction_idx] = 0; + } + + __syncthreads(); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + AccDataType, // SrcData, + AccDataType, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, // SrcData, + CDataType, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, + 1, + 1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + 
>{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock), + CElementwiseOperation{}}; + +#if 0 + if(threadIdx.x == 0) { + printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast(blockIdx.x), + reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end), + __builtin_amdgcn_readfirstlane(spatial_idx[I0]), + __builtin_amdgcn_readfirstlane(spatial_idx[I1])); + } +#endif + if(threadIdx.x == 0) + { + atomicAdd(&p_semaphore[reduction_idx], 1); + } + + wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count); + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + 
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + + continue; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * + MPerBlock * NPerBlock; + while(true) { uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( @@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 iter_end - 1, tile_idx, iter_offset); iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, - problem.MPadded, - problem.K, - problem.KPadded, - problem.StrideA, - problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, - problem.KPadded, - problem.N, - problem.NPadded, - problem.StrideB, - problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf 
= make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto block_work_idx = block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); @@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared_0), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock .GetElementSpaceSize()); + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op}; + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CShuffleDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + 
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, @@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } else if(is_sk_block) { - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global - .template Run( + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } } if constexpr(access_id < num_access - 1) { @@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }); } + // exit 
condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(0); + } } } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 36797a906a267cb4a08bf19bc00efa48e1c7e945..a43f0f880ae915a54fb4a59a98f255a14f9c6b48 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -127,7 +127,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct GridwiseGemm_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThisThreadBlock = ThisThreadBlock; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 using GemmSpecialization = tensor_operation::device::GemmSpecialization; + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == 
GemmSpecialization::MNKPadding) { @@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } } } @@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead; + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; } else if constexpr(is_same_v) { @@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead; + if constexpr(!PermuteB) + { + 
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } } if(blockIdx.z < static_cast(karg.KBatch - 1)) @@ -625,9 +677,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( AK0Number * Number{}, Number{}, AK1Number), @@ -761,10 +812,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( BK0Number * Number{}, Number{}, BK1Number), @@ -946,8 +995,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } @@ -1312,8 +1361,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_buf_pong = make_dynamic_buffer( static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..366a6c59c2e6fef869c11ac971230b3690f01a91 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -0,0 +1,2208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = 
karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + using BScaleType = ck::half_t; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto 
CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t 
KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + 
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const 
BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation 
a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + } + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { 
+ c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + 
make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock 
/ N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + 
make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + 
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, 
NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + 
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + 
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // b scale + // static_assert(KPerBlock <= ScaleBlockK); + static constexpr auto mfma = MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + static constexpr auto ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * 
KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, 
I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename 
ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const 
auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + 
ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), 
a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // B scale + static constexpr auto mfma = MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + const index_t ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + 
b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + 
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + 
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + 
c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index c7038ed4fa358198e447ae4afec7f7291bd46358..e5a31f8d1feaca64c2d14727976dbdd6bb401e48 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -41,7 +41,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -76,7 +76,7 @@ __global__ void __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead; + a_k_split_offset = k_id * karg.KRead; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead; + b_k_split_offset = k_id * karg.KRead; } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index da6b1b304e58b06b3399208af222ce0ef81d277f..813acfa656d6a240b018053fb44cb41817fc7635 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -38,8 +38,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 5617f67f8be20f9d775f157c6e24804c4a906241..b41e747a3aa562cc7c1bd4295d47b63c0593db4c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // when mfma if fixed, remove this section and update // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, // throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatAAdjusted = conditional_t, ck::bhalf_t, ComputeTypeA>; using FloatBAdjusted = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 4f3caff24893fc10239bce5d0d3290a0590ee5ea..5c3d9b7ba4bc0f8899b0ee6eeabbbaa89b503c16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index d7a6a3624410dea9777c7cf4299ff9c525696194..21315c2567900120eebd5f126046a39da2857ba0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -212,7 +212,7 @@ template ::type = false> struct ThreadwiseTensorSliceTransfer_v2 { - static_assert((InvalidElementAsNaN && !std::is_integral::value) || + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || (!InvalidElementAsNaN), "Filling invalid element as NaN is only for floating point types"); @@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { @@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! 
Not divisible"); + + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } } template src_tmp_vector; + vector_type_maker_t src_tmp_vector; using src_vector_t = typename decltype(src_tmp_vector)::type; @@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4 if constexpr(SrcBuffer::IsDynamicBuffer()) { src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); } else if constexpr(SrcBuffer::IsStaticBuffer()) { @@ -1133,9 +1146,236 @@ struct ThreadwiseTensorSliceTransfer_v4 }); } - if constexpr(is_same, f8_t>::value && - is_same, half_t>::value && - SrcScalarPerVector % 2 == 0) + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using 
dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + }); + } + + // Fuse scale + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstData& scale, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! 
DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + 
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } + + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + vector_type scale_vector; + scale_vector.template AsType()(Number<0>{}) = scale; + scale_vector.template AsType()(Number<1>{}) = scale; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + using scale_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::DequantPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i], + scale_vector.template AsType()[Number<0>{}]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + 
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) { // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // DstData) @@ -1304,7 +1544,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; -// Specilized for WMMA-Navi3 +// Specialized for gfx11 // A single Wave32 is composed by double row // Data exchange allowed between these two rows // This RowLane Dst buf will be filled from two Src buf @@ -1439,7 +1679,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ElementwiseOperation element_op_{}; }; -// Specilized for WMMA-Navi4 +// Specialized for gfx12 template {}; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr auto SrcScalarPerVector = Number{}; + static constexpr auto DstScalarPerVector = Number{}; __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, @@ -67,6 +90,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != 
DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); + } } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -95,11 +129,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; @@ -177,12 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, is_src_valid); - using src_vector_type = vector_type_maker_t; - using src_vector_t = typename src_vector_type::type; - - auto src_vector_container = - src_vector_type{src_buf.template Get(src_coord_.GetOffset(), true)}; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; dst_vector_type op_r_v; @@ -193,17 +221,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(decltype(src_element_op_)::is_pack8_invocable) return math::min(8, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack4_invocable) return math::min(4, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if 
constexpr(decltype(src_element_op_)::is_pack2_invocable) return math::min(2, SrcScalarPerVector); } - return 1; + else + { + return 1; + } }; constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); @@ -211,11 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_elem_op_vec_t = typename vector_type::type; using dst_elem_op_vec_t = typename vector_type::type; - static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if needed - src_element_op_(op_r_v.template AsType()(idx), - src_vector_container.template AsType()[idx]); - }); + using VectorSizeLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + using VectorOffsetsLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + static_for<0, tuple_element_t::Size(), 1>{}( + [&](auto v_idx) { + constexpr auto VectorLoadSize = + tuple_element_t::At(v_idx); + constexpr auto LoadOffset = + tuple_element_t::At(v_idx); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + src_vector_container src_vector = + src_vector_container{src_buf.template Get( + src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; + + static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if + // needed + src_element_op_( + op_r_v.template AsType()(idx + LoadOffset), + src_vector.template AsType()[idx]); + }); + }); // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) @@ -276,10 +361,9 @@ struct 
ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else - // OOB Check constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -350,6 +434,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { + static_assert(!is_same_v, pk_i4_t>, + "in-register transpose is not supported for pk_i4_t"); // each transpose does // DstScalarPerVector # of src vectors in src_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ @@ -410,7 +496,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } else { - static_ford{}([&](auto idx) { + constexpr auto packed_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; + + static_ford{}([&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } @@ -438,7 +529,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -526,13 +617,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // apply DstElementwiseOperation dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); - - dst_vector_container.template AsType()(i) = dst_v; }); // copy data from dst_vector_container to dst_buf dst_buf.template Set( - dst_coord_.GetOffset(), + dst_coord_.GetOffset() / PackedSize, is_dst_valid, dst_vector_container.template AsType()[I0]); @@ -586,7 +675,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 
// scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -644,7 +733,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -730,7 +819,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -779,7 +868,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -790,7 +879,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { // 1st stage of transforms constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index b435a2a1293c87cd70bee4130e5a15f60bec6bda..1abae56be4d3625e5eaaf978a7350ca011adfe37 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp 
+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -307,7 +307,7 @@ struct wmma_type{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // * Fixed for gfx11, Will be wave mode dependent on gfx12 // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; // * num_acc_vgprs_per_wave alone M direction diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 24fac91e22a01c7fbf7744ab48e7256d7e2ef900..4f20487b9b7e531e6b284162541934bbebebcfac 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -37,7 +37,17 @@ enum struct MfmaInstr mfma_f32_32x32x16f8bf8, mfma_f32_16x16x32f8bf8, mfma_f32_32x32x16bf8f8, - mfma_f32_16x16x32bf8f8 + mfma_f32_16x16x32bf8f8, + mfma_f32_32x32x16f16, + mfma_f32_16x16x32f16, + mfma_f32_32x32x16bf16, + mfma_f32_16x16x32bf16, + mfma_i32_32x32x32i8, + mfma_i32_16x16x64i8, + mfma_f32_32x32x64f8f6f4, + mfma_f32_16x16x128f8f6f4, + mfma_scale_f32_32x32x64f8f6f4, + mfma_scale_f32_16x16x128f8f6f4 }; template @@ -198,6 +208,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type 
+{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -264,6 +318,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -286,6 +362,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t 
k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -440,6 +538,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x64i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -638,16 +780,115 @@ struct mfma_type } }; +// TODO: fix mfma...f8f6f4 instructions +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? 
group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? 
+ // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? 
+ static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + template + typename additional_type = base_type, + bool is_single_rate_mfma = false> struct MfmaSelector { template + typename additional_type_ = base_type_, + bool is_single_rate_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -711,13 +952,32 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16f16; +#else + return MfmaInstr::mfma_f32_32x32x8f16; +#endif + } + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32f16; +#else + return MfmaInstr::mfma_f32_16x16x16f16; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } @@ -741,7 +1001,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -751,7 +1023,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) 
+ return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -760,7 +1044,18 @@ struct MfmaSelector #endif } -#if defined(CK_USE_AMD_MFMA_GFX940) +#if defined(__gfx950__) + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x32i8; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x64i8; + } +#elif defined(__gfx942__) template <> constexpr auto GetMfma() { @@ -832,8 +1127,8 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = - mfma_type()>{}; + static constexpr auto selected_mfma = mfma_type< + GetMfma()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1135,7 +1430,13 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942- + static constexpr auto + mfma = MfmaSelector < base_type, + MPerXdlops, NPerXdlops, additional_type, + ((is_same::value || is_same::value) && KPack <= 4) + ? true + : false > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 2be0b66812434eb5127f5a97534ca4dc36b68273..8df0d885b93141b21b907183ebd9b6505af658a3 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,245 +13,614 @@ namespace ck { namespace tensor_operation { -namespace { template < index_t NDimSpatial, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, + index_t AK1, + index_t BK1, + index_t GemmMPerBlock, + index_t GemmNPerBlock, + index_t GemmKPerBlock, + bool DoPadGemmM, + bool DoPadGemmN, typename ALayout, - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization> -constexpr auto make_out_grid_desc(const index_t N, - const index_t Do, - const index_t Ho, - const index_t Wo, - const index_t K, - const std::array& out_g_n_k_wos_strides) + typename BLayout, + typename CLayout, + bool SplitN = false, + typename ADataType = float, + typename CDataType = float, + index_t NumGroupsToMerge = 1, + typename IndexType = index_t> +struct TransformConvBwdDataToGemm_v1 { - const auto KStride = Number<1>{}; + private: + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; - if constexpr(is_same_v) - { - const index_t NStride = out_g_n_k_wos_strides[1]; - const index_t HiStride = out_g_n_k_wos_strides[3]; - const index_t WiStride = out_g_n_k_wos_strides[4]; - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) - { + static constexpr auto NonSpatialDimsNum = Number<3>{}; - return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), - make_tuple(WiStride, KStride)); - } - else + static constexpr auto DIdx = NonSpatialDimsNum; + static constexpr auto HIdx = + NDimSpatial == 2 ? NonSpatialDimsNum : Number{}; + static constexpr auto WIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + static constexpr auto ZIdx = NonSpatialDimsNum; + static constexpr auto YIdx = + NDimSpatial == 2 ? 
NonSpatialDimsNum : Number{}; + static constexpr auto XIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) { - return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K), - make_tuple(NStride, HiStride, WiStride, KStride)); + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); } + + return acc; } - else if constexpr(is_same_v) + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths, + const ConvDimsType& a_g_n_k_wos_strides, + const ConvDimsType& c_g_n_c_wis_lengths, + const ConvDimsType& c_g_n_c_wis_strides) { - const index_t NStride = out_g_n_k_wos_strides[1]; - const index_t DoStride = out_g_n_k_wos_strides[3]; - const index_t HoStride = out_g_n_k_wos_strides[4]; - const index_t WoStride = out_g_n_k_wos_strides[5]; - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_k_wos_lengths[I1]; + + if(element_space_size > TwoGB) { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); - return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), - make_tuple(WoStride, KStride)); + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up 
to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Not possible to support even after split N. + // Too large tensor. + return N; + } } else { - return make_naive_tensor_descriptor( - make_tuple(N, Do, Ho, Wo, K), - make_tuple(NStride, DoStride, HoStride, WoStride, KStride)); + // Split N is not needed. + return N; } } - else if constexpr(is_same_v) + + public: + __host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {} + + template + __host__ __device__ TransformConvBwdDataToGemm_v1( + const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base) + : N_{static_cast(transform_conv_bwd_data_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_bwd_data_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_bwd_data_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_bwd_data_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_bwd_data_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_bwd_data_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_bwd_data_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_bwd_data_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_bwd_data_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_bwd_data_to_gemm_base.X_)}, + K_{static_cast(transform_conv_bwd_data_to_gemm_base.K_)}, + C_{static_cast(transform_conv_bwd_data_to_gemm_base.C_)}, + DiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.DiStride_)}, + HiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.HiStride_)}, + WiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.WiStride_)}, + DoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.DoStride_)}, + HoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.HoStride_)}, + 
WoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.WoStride_)}, + CStrideTensorB_{ + static_cast(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)}, + CStrideTensorC_{ + static_cast(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)}, + KStrideTensorA_{ + static_cast(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)}, + KStrideTensorB_{ + static_cast(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)}, + NStrideTensorA_{ + static_cast(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)}, + NStrideTensorC_{ + static_cast(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)}, + ConvStrideD_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadW_)}, + IdxZTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)}, + IdxYTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)}, + IdxXTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)}, + GcdStrideDilationD_{ + static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)}, + GcdStrideDilationH_{ + 
static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)}, + GcdStrideDilationW_{ + static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)}, + ZTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.ZTilde_)}, + YTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.YTilde_)}, + XTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.XTilde_)}, + DTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.DTilde_)}, + HTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.HTilde_)}, + WTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.WTilde_)}, + ZDot_{static_cast(transform_conv_bwd_data_to_gemm_base.ZDot_)}, + YDot_{static_cast(transform_conv_bwd_data_to_gemm_base.YDot_)}, + XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)} { - // assume packed - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) + } + + template + __host__ __device__ + TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths, + const ConvDimsType& a_g_n_k_wos_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_c_wis_lengths, + const ConvDimsType& c_g_n_c_wis_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes) + : Hi_{c_g_n_c_wis_lengths[HIdx]}, + Wi_{c_g_n_c_wis_lengths[WIdx]}, + Ho_{a_g_n_k_wos_lengths[HIdx]}, + Wo_{a_g_n_k_wos_lengths[WIdx]}, + Y_{b_g_k_c_xs_lengths[YIdx]}, + X_{b_g_k_c_xs_lengths[XIdx]}, + K_{a_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + HiStride_{c_g_n_c_wis_strides[HIdx]}, + WiStride_{c_g_n_c_wis_strides[WIdx]}, + HoStride_{a_g_n_k_wos_strides[HIdx]}, + WoStride_{a_g_n_k_wos_strides[WIdx]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + 
CStrideTensorC_{c_g_n_c_wis_strides[I2]}, + KStrideTensorA_{a_g_n_k_wos_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + NStrideTensorA_{a_g_n_k_wos_strides[I1]}, + NStrideTensorC_{c_g_n_c_wis_strides[I1]}, + ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]}, + ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]}, + ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]}, + ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]}, + InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]}, + InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]}, + InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]}, + InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]}, + IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]}, + IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); + + if constexpr(SplitN) { - return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + N_ = GetSplitedNSize( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides); } else { - return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + N_ = c_g_n_c_wis_lengths[I1]; } - } - else if constexpr(is_same_v) - { - // assume packed - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) + if constexpr(NDimSpatial == 3) { - return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + Di_ = c_g_n_c_wis_lengths[DIdx]; + Do_ = a_g_n_k_wos_lengths[DIdx]; + Z_ = b_g_k_c_xs_lengths[ZIdx]; + DiStride_ = c_g_n_c_wis_strides[DIdx]; + DoStride_ = a_g_n_k_wos_strides[DIdx]; + ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum]; + ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum]; + InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum]; + IdxZTilde_ = 
tildes[ZIdx - NonSpatialDimsNum]; + GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_); + ZTilde_ = ConvStrideD_ / GcdStrideDilationD_; + DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_); + ZDot_ = math::integer_divide_ceil(Z_, ZTilde_); } else { - return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1; + InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0; } - } - else - { - throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); - } -} -template -constexpr auto make_wei_grid_desc( - const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C) -{ + GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_); + GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_); - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); - } - else - { - throw std::runtime_error("wrong! 
unsupported layout: " + BLayout::name()); - } -} - -template -constexpr auto make_in_grid_desc(const index_t N, - const index_t Di, - const index_t Hi, - const index_t Wi, - const index_t C, - const std::array& in_g_n_c_wis_strides) -{ + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; - if constexpr(is_same_v || - is_same_v || - is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), - make_tuple(in_g_n_c_wis_strides[1], - in_g_n_c_wis_strides[3], - in_g_n_c_wis_strides[4], - in_g_n_c_wis_strides[2])); + HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + + YDot_ = math::integer_divide_ceil(Y_, YTilde_); + XDot_ = math::integer_divide_ceil(X_, XTilde_); } - else if constexpr(is_same_v || - is_same_v) + +#if 0 // At now not supported to split tensor + __host__ bool AreDescriptorsSmallerThan2GB() const { - return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), - make_tuple(in_g_n_c_wis_strides[1], - in_g_n_c_wis_strides[3], - in_g_n_c_wis_strides[4], - in_g_n_c_wis_strides[5], - in_g_n_c_wis_strides[2])); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_; + + bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; } - else + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, 
+ CDataType* c_grid_ptr_base) const { - throw std::runtime_error("wrong! unsupported layout: " + CLayout::name()); - } -} + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 
0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = (Do_ / 2) * DoStride_; + c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; -} // namespace + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; -template < - index_t NDimSpatial, - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, - index_t AK1, - index_t BK1, - index_t GemmMPerBlock, - index_t GemmNPerBlock, - index_t GemmKPerBlock, - bool DoPadGemmM, - bool DoPadGemmN> -struct TransformConvBwdDataToGemm_v1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; - static constexpr auto NonSpatialDimsNum = Number<3>{}; + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = (Ho_ / 2) * HoStride_; + c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; - static constexpr auto DIdx = Number{}; - static constexpr 
auto HIdx = - NDimSpatial == 2 ? Number{} : Number{}; - static constexpr auto WIdx = - NDimSpatial == 2 ? Number{} : Number{}; + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; - static constexpr auto ZIdx = Number{}; - static constexpr auto YIdx = - NDimSpatial == 2 ? Number{} : Number{}; - static constexpr auto XIdx = - NDimSpatial == 2 ? Number{} : Number{}; + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; - template || - is_same_v || - is_same_v || - is_same_v), - bool>::type = false> - static auto MakeADescriptor_AK0_M_AK1( - const std::array& out_g_n_k_wos_lengths, - const std::array& out_g_n_k_wos_strides, - const std::array& wei_g_k_c_xs_lengths, - const std::array& /* wei_g_k_c_xs_strides */, - const std::array& in_g_n_c_wis_lengths, - const std::array& /* in_g_n_c_wis_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& /* input_right_pads */, - const std::array& tildes) + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = (Wo_ / 2) * WoStride_; + c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const { - index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; - index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; - index_t i_xtilde = 
tildes[XIdx - NonSpatialDimsNum]; + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + + // Calculate start position in input for right tensor + const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_); + const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_); + const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_); + // Calculate last position in input for left tensor + const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_); + const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_); + const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_); + + + if(Di_!=1) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Di_ = Di_ / 2; + conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx; + conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx; + ; + // Calcualte offsets + a_right_offset = do_right_transformer_start_idx * DoStride_; + c_right_offset = (Di_ / 2) * DiStride_; + } + else if(Hi_!=1) + { + // Apply new sizes + // Split output on half + 
conv_to_gemm_transformer_left.Hi_ = Hi_ / 2; + conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + // // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + // Calculate new input size + conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ; + conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ; + ; + // Calcualte offsets + a_right_offset = ho_right_transformer_start_idx * HoStride_; + c_right_offset = (Hi_ / 2) * HiStride_; + } + else if(Wi_!=1) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Wi_ = Wi_ / 2; + conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + // Calculate new input size + conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx; + conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx; + ; + // Calcualte offsets + a_right_offset = wo_right_transformer_start_idx * WoStride_; + c_right_offset = (Wi_ / 2) * WiStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif - const index_t N = in_g_n_c_wis_lengths[1]; - const index_t K = wei_g_k_c_xs_lengths[1]; + __host__ __device__ auto MakeOutGridDesc() const + { + if constexpr(is_same_v) + { + if 
constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { - const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; - const index_t Hi = in_g_n_c_wis_lengths[HIdx]; - const index_t Wi = in_g_n_c_wis_lengths[WIdx]; + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_)); + } + } + else if constexpr(is_same_v) + { + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { - const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; - const index_t Ho = out_g_n_k_wos_lengths[HIdx]; - const index_t Wo = out_g_n_k_wos_lengths[WIdx]; + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_)); + } + } 
+ else + { + throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); + } + } - const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; - const index_t Y = wei_g_k_c_xs_lengths[YIdx]; - const index_t X = wei_g_k_c_xs_lengths[XIdx]; + __host__ __device__ auto MakeWeiGridDesc() const + { - const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; - const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; - const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + BLayout::name()); + } + } - const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; - const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; - const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + __host__ __device__ auto MakeInGridDesc() const + { - const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; - const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_)); + } + else if constexpr(is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_)); + } + else + { + throw std::runtime_error("wrong! 
unsupported layout: " + CLayout::name()); + } + } + template < + typename ALayout_ = ALayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const + { // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d - const auto out_grid_desc = - make_out_grid_desc( - N, Do, Ho, Wo, K, out_g_n_k_wos_strides); + const auto out_grid_desc = MakeOutGridDesc(); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t AK0 = math::integer_divide_ceil(K, AK1); + const index_t AK0 = math::integer_divide_ceil(K_, AK1); // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( out_grid_desc, - make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), make_unmerge_transform(make_tuple(AK0, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{})); @@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1 } else { - const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto ZTilde = ConvStrideD / GcdStrideDilationD; - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto ZDot = math::integer_divide_ceil(Z, ZTilde); - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - - const auto DTilde = - Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); - const auto HTilde = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = - Wo + 
math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilde and WTilde that contribute to non-padding area of input tensor const auto IDTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); const auto IDTildeSliceEnd = math::min( - DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); const auto IHTildeSliceEnd = math::min( - HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); const auto IWTildeSliceEnd = math::min( - WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + 
const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); if constexpr(NDimSpatial == 2) { // A: output tensor const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( out_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(K)), + make_tuple(make_pass_through_transform(N_), + 
make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1 // A: output tensor const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( out_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Do, I0, I0), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N), + make_tuple(make_pass_through_transform(N_), make_embed_transform( - make_tuple(ZDot, DTilde), - make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), make_embed_transform( - 
make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), make_embed_transform( - make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(ZDot, I0, ZDotSlice), - make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(K)), + make_tuple( + make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, make_tuple( - make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), - make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))), + 
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform( + make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1 } } - template || - is_same_v), + (is_same_v || + is_same_v), bool>::type = false> - static auto MakeBDescriptor_BK0_N_BK1( - const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, - const std::array& wei_g_k_c_xs_lengths, - const std::array& /* wei_g_k_c_xs_strides */, - const std::array& in_g_n_c_wis_lengths, - const std::array& /* in_g_n_c_wis_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& /* input_left_pads */, - const std::array& /* input_right_pads */, - const std::array& tildes) + __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const { - index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; - index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; - index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; - - const index_t N = in_g_n_c_wis_lengths[1]; - const index_t K = wei_g_k_c_xs_lengths[1]; - const index_t C = wei_g_k_c_xs_lengths[2]; - - const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; - const index_t Ho = out_g_n_k_wos_lengths[HIdx]; - const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - - const index_t Z = NDimSpatial == 3 ? 
wei_g_k_c_xs_lengths[ZIdx] : 1; - const index_t Y = wei_g_k_c_xs_lengths[YIdx]; - const index_t X = wei_g_k_c_xs_lengths[XIdx]; - - const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; - const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; - const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - - const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; - const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; - // assume packed // k_y_x_c for 2d or k_z_y_x_c for 3d - const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C); + const auto wei_grid_desc = MakeWeiGridDesc(); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t BK0 = math::integer_divide_ceil(K, BK1); + const index_t BK0 = math::integer_divide_ceil(K_, BK1); // B: weight tensor const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( @@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1 } else { - const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto 
GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto ZTilde = ConvStrideD / GcdStrideDilationD; - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto ZDot = math::integer_divide_ceil(Z, ZTilde); - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - // GemmK is different for each GEMM - const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); // B weight tensor if constexpr(NDimSpatial == 2) @@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_grid_desc, make_tuple( - make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), + make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K), - 
make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, @@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_grid_desc, - make_tuple( - make_pass_through_transform(K), - make_embed_transform(make_tuple(ZDot, ZTilde), - make_tuple(ConvStrideD / GcdStrideDilationD, I1)), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(K_), + make_embed_transform( + make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ 
-659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_slice_transform(ZDot, I0, ZDotSlice), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ztilde), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, @@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), - make_pass_through_transform(C)), + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1 } } - template || - is_same_v || - is_same_v || - is_same_v || - is_same_v), - bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, - const std::array& wei_g_k_c_xs_lengths, - const std::array& /* wei_g_k_c_xs_strides */, - const std::array& in_g_n_c_wis_lengths, - const std::array& in_g_n_c_wis_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& 
input_left_pads, - const std::array& input_right_pads, - const std::array& tildes) + template < + typename CLayout_ = CLayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const { - index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; - index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; - index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; - - const index_t N = in_g_n_c_wis_lengths[1]; - const index_t C = wei_g_k_c_xs_lengths[2]; - - const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; - const index_t Hi = in_g_n_c_wis_lengths[HIdx]; - const index_t Wi = in_g_n_c_wis_lengths[WIdx]; - - const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; - const index_t Ho = out_g_n_k_wos_lengths[HIdx]; - const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - - const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; - const index_t Y = wei_g_k_c_xs_lengths[YIdx]; - const index_t X = wei_g_k_c_xs_lengths[XIdx]; - - const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; - const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; - const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; - - const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum]; - const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum]; - const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum]; - - const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; - const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; - const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - - const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; - const index_t ConvDilationW = 
conv_filter_dilations[WIdx - NonSpatialDimsNum]; - // assume strided // n_hi_wi_c for 2d n_di_hi_wi_c for 3d - const auto in_grid_desc = - make_in_grid_desc(N, Di, Hi, Wi, C, in_g_n_c_wis_strides); + const auto in_grid_desc = MakeInGridDesc(); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_grid_desc, make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); @@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1 in_n_y_ho_x_wo_c_grid_desc, make_tuple(make_freeze_transform(I0), make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), + make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); @@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1 const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_grid_desc, make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - 
make_pass_through_transform(C)), + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, @@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(make_freeze_transform(I0), make_freeze_transform(I0), make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, @@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1 } else { - const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto ZTilde = ConvStrideD / GcdStrideDilationD; - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto DTilde = - Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); - const auto HTilde = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = - Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on DTilde, HTilde and WTilde that contribute to // non-padding area of input tensor const auto IDTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - 
I1)), ConvStrideH); + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); const auto IDTildeSliceEnd = math::min( - DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); const auto IHTildeSliceEnd = math::min( - HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); const auto IWTildeSliceEnd = math::min( - WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; @@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1 { const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), 
- make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1 { const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( in_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - 
make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_dip_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(ZTilde, DTilde), - make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ztilde), - make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, 
IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, make_tuple( - make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1 } // for input bias - template || - is_same_v), + (is_same_v || + is_same_v), bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, - const std::array& wei_g_k_c_xs_lengths, - const std::array& /* wei_g_k_c_xs_strides */, - const std::array& in_g_n_c_wis_lengths, - const std::array& /* in_g_n_c_wis_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& /* input_right_pads */, - const std::array& /* tildes */) + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t N = in_g_n_c_wis_lengths[1]; - const index_t C = wei_g_k_c_xs_lengths[2]; - - const index_t Hi = in_g_n_c_wis_lengths[3]; - const index_t Wi = in_g_n_c_wis_lengths[4]; - - const index_t Ho = out_g_n_k_wos_lengths[3]; - const 
index_t Wo = out_g_n_k_wos_lengths[4]; - - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { const auto in_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); return in_gemmm_gemmn_grid_desc; } else { - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto HTilde = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = - Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilde and WTilde that contribute to non-padding area of input tensor const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); const auto IHTildeSliceEnd = math::min( - HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) 
+ I1); const auto IWTildeSliceEnd = math::min( - WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // bias tensor const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor( - make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1)); + make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1)); const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( in_gemmmraw_gemmnraw_grid_desc, @@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1 return in_gemmm_gemmn_grid_desc; } } + + IndexType N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType DiStride_, HiStride_, WiStride_; + IndexType DoStride_, HoStride_, WoStride_; + IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_; + IndexType NStrideTensorA_, NStrideTensorC_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_; + IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_; + IndexType ZTilde_, YTilde_, XTilde_; + IndexType DTilde_, HTilde_, WTilde_; + IndexType ZDot_, YDot_, XDot_; }; } // namespace tensor_operation diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index b91b12ad52380a82ac6213cea27016700aca1461..3db94deccb465da766483521bf7bef7ca332a02c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ 
b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,10 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -148,8 +147,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -201,11 +200,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I0]}, ZYX_{X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -219,8 +222,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -272,11 +275,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I1]}, ZYX_{Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -290,8 +297,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& 
a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -343,11 +350,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I2]}, ZYX_{Z_ * Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -478,11 +489,11 @@ struct TransformConvFwdToGemm // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeADescriptor_M_K() const { if constexpr(ConvForwardSpecialization == @@ -691,11 +702,11 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeADescriptor_M_K() const { @@ -932,7 +943,7 @@ struct TransformConvFwdToGemm } template || is_same_v || is_same_v), @@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { using FilterSizeNumType = - std::conditional_t, - std::conditional_t, Number<27>>>; + ck::conditional_t, + ck::conditional_t, Number<27>>>; if constexpr(NumGroupsToMerge == 1) { @@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm template < typename BLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v 
|| + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( @@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm return wei_gemmn_gemmk_desc; } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), @@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + index_t NDimSp = NDimSpatial, + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Wo_; @@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Ho_ * Wo_; @@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm template || is_same_v || is_same_v), diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index d4ee5c886cd416e02de46a83213fc0bdc4e621ad..328e37d00971eee2ee50270a320673f0e1f88a9d 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT 
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -429,7 +429,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; @@ -549,8 +550,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); @@ -578,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; @@ -843,8 +846,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? tmp : vector_t(0); #endif } @@ -873,8 +876,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? tmp : vector_t(customized_value); } @@ -1018,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; static_assert(bytes_per_thread == dword_bytes); +#ifndef CK_CODE_GEN_RTC const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); +#else + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); +#endif const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? 
global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM T* lds_ptr = lds_base_ptr + lds_offset; +#ifndef CK_CODE_GEN_RTC auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#else + auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#endif asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), @@ -1035,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = +#ifndef CK_CODE_GEN_RTC reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast(lds_base_ptr + lds_offset)); +#else + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#endif llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42b784d303766ccf3e3dd1ba0d7ee296f34f3d85 --- /dev/null +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -0,0 +1,996 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/random_gen.hpp" +#include "ck/utility/type.hpp" + +#ifdef CK_USE_FNUZ_FP8 +#define CK_USE_FNUZ_FP8 1 +#else +#define CK_USE_FNUZ_FP8 0 +#endif + +#ifdef CK_USE_OCP_FP8 +#define CK_USE_OCP_FP8 1 +#else +#define CK_USE_OCP_FP8 0 +#endif + +#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \ + defined(__gfx1201__) || defined(__gfx950__)) && \ + __HIP_DEVICE_COMPILE__ +#define CK_FP8_CVT_FAST_PATH 1 +#else +#define CK_FP8_CVT_FAST_PATH 0 +#endif + +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ +#define CK_OCP_FP8_CVT_FAST_PATH 1 +#else +#define CK_OCP_FP8_CVT_FAST_PATH 0 +#endif + +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + +typedef unsigned char fp8_storage_t; + +/** + * \brief Describes FP8 interpretation + */ +enum class ck_fp8_interpretation_t +{ + CK_E4M3_OCP = 0, // OCP E4M3 + CK_E5M2_OCP = 1, // OCP E5M2 + CK_E4M3_FNUZ = 2, // FP8 + CK_E5M2_FNUZ = 3, // BF8 +}; + +/** + * \brief Describes saturation behavior + */ +enum class ck_saturation_t +{ + CK_NOSAT = 0, // No saturation - replace with NaN or Inf + CK_SATFINITE = 1, // Saturate to finite +}; + +namespace fp8_impl { + +typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef float float2_t __attribute__((ext_vector_type(2))); + +__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} +__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} + +__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) == 0x7f; +} +__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) > 0x7c; +} + +// The conversion function is from rocblas +// 
https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 +// This has been modified to handle double types as well +template +__host__ __device__ static inline T cast_from_f8(fp8_storage_t x) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, "only half, float and double are supported"); + + constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52); + + T fInf, fNegInf, fNaN, fNeg0, fmax, fmin; + if constexpr(is_half) + { + const unsigned short int ihInf = 0x7C00; + const unsigned short int ihNegInf = 0xFC00; + const unsigned short int ihNaN = 0x7C01; + const unsigned short int ihNeg0 = 0x8000; + /* Max number in e5m2 57344*/ + const unsigned short int ifmax = 0x7B00; + const unsigned short int ifmin = 0xFB00; + + fInf = bit_cast<_Float16>(ihInf); + fNegInf = bit_cast<_Float16>(ihNegInf); + fNaN = bit_cast<_Float16>(ihNaN); + fNeg0 = bit_cast<_Float16>(ihNeg0); + fmax = bit_cast<_Float16>(ifmax); + fmin = bit_cast<_Float16>(ifmin); + } + else if constexpr(is_float) + { + const unsigned int ifInf = 0x7F800000; + const unsigned int ifNegInf = 0xFF800000; + const unsigned int ifNaN = 0x7F800001; + const unsigned int ifNeg0 = 0x80000000; + /* Max number in e5m2 57344*/ + const unsigned int ifmax = 0x47600000; + const unsigned int ifmin = 0xC7600000; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + else if constexpr(is_double) + { + const unsigned long long ifInf = 0x7FF0000000000000ull; + const unsigned long long ifNegInf = 0xFFF0000000000000ull; + const unsigned long long ifNaN = 0x7FF0000000000001ull; + const unsigned long long ifNeg0 = 
0x8000000000000000ull; + /* Max number in e5m2 57344*/ + const unsigned long long ifmax = 0x40EC000000000000ull; + const unsigned long long ifmin = 0xC0EC000000000000ull; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + + if(x == 0) + { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if constexpr(is_fnuz) + { + if(x == 0x80) + { + return fNaN; + } + } + else + { + if(x == 0x80) + { + return fNeg0; + } + if constexpr(we == 4) + { // e4m3 + if((x & 0x7F) == 0x7F) + { + return fNaN; + } + } + else if((x & 0x7C) == 0x7C) + { // e5m2 + if((x & 0x3) == 0) + { + if constexpr(clip) + { + return sign ? fmin : fmax; + } + return sign ? fNegInf : fInf; + } + return fNaN; + } + } + + typename std::conditional< + sizeof(T) == 2, + unsigned short int, + typename std::conditional::type>::type + retval; + + if constexpr(we == 5 && is_half && !is_fnuz) + { + retval = x << 8; + return bit_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 
1 : 0); + + // subnormal input + if(exponent == 0) + { +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __clz(mantissa) - (32 - wm); +#else + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); +#endif + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if constexpr(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else if constexpr(sizeof(T) == 4) + retval = (sign << 31) | (exponent << 23) | mantissa; + else + retval = (sign << 63) | (static_cast(exponent) << 52) | mantissa; + + return bit_cast(retval); +} + +#if CK_FP8_CVT_FAST_PATH +template +static __device__ float cast_to_f32_from_f8(fp8_storage_t v) +{ + union + { + unsigned int i32val; + unsigned char i8val[4]; + } val; + val.i8val[0] = v; + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); + } + else + { + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); + } +} + +template +static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) +{ + const auto i16val = bit_cast(v); + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + 
"Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false); + } + else + { + return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); + } +} +#endif + +} // namespace fp8_impl + +struct f8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E4M3_OCP; + + static constexpr unsigned int we = 4; // exponent width + static constexpr unsigned int wm = 3; // mantissa width + + __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const +#else + __host__ explicit operator float() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +struct bf8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E5M2_OCP; + + static constexpr unsigned int we = 5; // exponent width + 
static constexpr unsigned int wm = 2; // mantissa width + + __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const + +#else + __host__ explicit operator float() const +#endif + { +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +template +__host__ __device__ static inline constexpr bool fp8_is_nan(T); + +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a) +{ + return fp8_impl::ocp_f8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a) +{ + return fp8_impl::ocp_bf8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a) +{ + return fp8_impl::fnuz_f8_is_nan(a); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) +{ + return fp8_impl::fnuz_bf8_is_nan(a); +} + +template || is_same_v || + is_same_v || is_same_v, + bool> = true> +__host__ __device__ static inline constexpr bool fp8_is_inf(T) +{ + return false; +} +template <> +__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a) +{ + return (a.data & 0x7f) == 0x7c; +} + +namespace fp8_impl { + +// Assertions to check for supported 
conversion types +#define __assert_ocp_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \ + interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } +#define __assert_fnuz_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \ + interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } + +__host__ __device__ static inline void +__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) +{ +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ +#if CK_USE_OCP_FP8 + __assert_ocp_support(interp); +#endif +#if CK_USE_FNUZ_FP8 + __assert_fnuz_support(interp); +#endif +#endif +} + +#if CK_FP8_CVT_FAST_PATH +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 +template +static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0) +{ + fp8_storage_t i8data; + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0); + } + } + else + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if 
constexpr(stochastic_rounding) + { + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } + else + { // RNE CVT + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, + val.fval, + ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} +#endif // CK_FP8_CVT_FAST_PATH + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 +// This has been modified to add double types conversion as well +template +__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, + "Only half, float and double can be cast to f8"); + + constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 
23 : 10); + + using T_bitwise = typename std::conditional< + sizeof(T) == 2, + unsigned short int, + typename std::conditional::type>::type; + T_bitwise x_bitwise = bit_cast(_x); + + unsigned long long x{x_bitwise}; + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + unsigned long long fInf, mask; + + if constexpr(sizeof(T) == 8) + { + head = x & 0xFFF0000000000000ull; + mantissa = x & 0xFFFFFFFFFFFFFull; + exponent = (head >> 52) & 0x7FF; + sign = head >> 63; + bias = 1023; + fInf = 0x7FF0000000000000ull; + mask = 0x7FFFFFFFFFFFFFFFull; + } + else if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + fInf = 0x7F800000; + mask = 0x7FFFFFFF; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + fInf = 0x7C00; + mask = 0x7FFF; + } + unsigned int signed_inf = 0; + unsigned int nan = 0; + if constexpr(is_fnuz) + { + signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80; + nan = 0x80; + } + else + { + if constexpr(we == 4) + { // e4m3 + signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f); + } + else + { // e5m2 + signed_inf = (sign << 7) + (clip ? 
0x7b : 0x7c); + } + nan = (sign << 7) + 0x7f; + } + // Max values + unsigned long long ifmax = 0; + if constexpr(sizeof(T) == 8) + { + if constexpr(we == 5) + { // 57344 + ifmax = 0x40EC000000000000ull; + } + else + { + if constexpr(is_fnuz) + { // 240 + ifmax = 0x406E000000000000ull; + } + else + { // 448 + ifmax = 0x407C000000000000ull; + } + } + } + else if(sizeof(T) == 4) + { + if constexpr(we == 5) + { + ifmax = 0x47600000; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x43700000; + } + else + { + ifmax = 0x43E00000; + } + } + } + else + { + if constexpr(we == 5) + { + ifmax = 0x7B00; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x5B80; + } + else + { + ifmax = 0x5F00; + } + } + } + // Deal with inf and NaNs + if((x & fInf) == fInf) + { + if constexpr(is_fnuz) + return signed_inf; + + return mantissa != 0 ? nan : signed_inf; + } + + if((x & mask) > ifmax) + { + return signed_inf; + } + + if(x == 0) + { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. 
+ /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + mostly concern fp16 here. In this case, f8 is usually in denormal. But there + could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has + exponent bias 16. It means that there are some numbers in fp16 denormal but they + are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers + where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 + (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, act_exponent could be larger. Just + // that it does not need shift mantissa + } + mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) == + (1ull << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part and + make something not midpoint look like midpoint. 
For example, the fp16 number + 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right + by 4 bits, it would look like midpoint. + */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1; + bool odd = + mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1ull << mfmt) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1ull << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - 1; + if(f8_exponent > max_exp) + { + if constexpr(clip) + { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return is_fnuz ? 
0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +/** + * \brief convert float to @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \param f float number + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#else +__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#endif + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif + } + + if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP) + { + return cast_to_f8(f, rng); + } + else + { + __hip_assert(false && "FP8 type is not supported by current target device"); + return 0; + } +#endif // CK_FP8_CVT_FAST_PATH +} + +/** + * \brief convert _Float16 to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x _Float16 value + * \return fp8_storage_t + */ +template +#if 
CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#else +__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#endif +{ + return cvt_float_to_fp8(static_cast(x)); +} + +} // namespace fp8_impl + +// Declare a template function for fp8 conversion using RNE +template +__host__ __device__ constexpr Y f8_convert_rne(X x); + +// convert fp32 to fp8 with rounding to nearest even +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert fp32 to bf8 with rounding to nearest even +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert _Float16 to fp8 with rounding to nearest even +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(_Float16 x) +{ + return f8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(_Float16 x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +// Declare a template function for fp8 conversion using stochastic rounding (SR) +template +__host__ __device__ constexpr Y f8_convert_sr(X x); + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +// convert fp32 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert _Float16 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(_Float16 x) +{ + return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +// convert _Float16 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(_Float16 x) +{ + return 
bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +#if CK_USE_OCP_FP8 +using f8_t = f8_ocp_t; +using bf8_t = bf8_ocp_t; +#define CK_FP8_TYPE_FNUZ 0 +#define CK_FP8_TYPE_OCP 1 +#else +using f8_t = f8_fnuz_t; +using bf8_t = bf8_fnuz_t; +#define CK_FP8_TYPE_FNUZ 1 +#define CK_FP8_TYPE_OCP 0 +#endif + +} // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 5dc67a5aded4af289d0240394f720af62e699eb4..113f3af4ae51adbb17512f10a6cdec55d535c40a 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -4,13 +4,34 @@ #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP -#include "data_type.hpp" #include "c_style_pointer_cast.hpp" +#include "data_type.hpp" // TODO: deprecate all amd_assembly_outer_product_xxx namespace ck { +inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) +{ + int c; + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d)); + return c; +} + +inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) +{ + half2_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +} + +inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) +{ + half2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index d6e1eab314e30184c669abe88f5a4cf7f5ea90c4..128c8e9a2c50ba9dc9d123c1e3dc4c036f39e872 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,10 +7,12 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" +#ifndef CK_CODE_GEN_RTC #include #include #include #include +#endif namespace ck { namespace detail { @@ -37,7 +39,7 @@ struct get_carrier<3> { using value_type = uint32_t; - std::array bytes; + Array bytes; static_assert(sizeof(bytes) <= sizeof(value_type)); // replacement of host std::copy_n() @@ -61,22 +63,22 @@ struct get_carrier<3> // method to trigger template substitution failure __device__ carrier(const carrier& other) noexcept { - copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + copy_n(other.bytes.begin(), bytes.Size(), bytes.begin()); } public: __device__ carrier& operator=(value_type value) noexcept { - copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + copy_n(reinterpret_cast(&value), bytes.Size(), bytes.begin()); return *this; } __device__ operator value_type() const noexcept { - std::byte result[sizeof(value_type)]; + ck::byte result[sizeof(value_type)]; - copy_n(bytes.begin(), bytes.size(), result); + copy_n(bytes.begin(), bytes.Size(), result); return *reinterpret_cast(result); } @@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) { constexpr unsigned object_size = sizeof(int64_t); constexpr unsigned second_part_offset = object_size / 2; - auto* const from_obj = reinterpret_cast(&value); - alignas(int64_t) std::byte to_obj[object_size]; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) ck::byte to_obj[object_size]; using Sgpr = uint32_t; @@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) return *reinterpret_cast(to_obj); } -template < - typename Object, - typename = std::enable_if_t && std::is_trivially_copyable_v>> +template && ck::is_trivially_copyable_v>> __device__ auto amd_wave_read_first_lane(const Object& obj) { using Size = unsigned; constexpr Size 
SgprSize = 4; constexpr Size ObjectSize = sizeof(Object); - auto* const from_obj = reinterpret_cast(&obj); - alignas(Object) std::byte to_obj[ObjectSize]; + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) ck::byte to_obj[ObjectSize]; constexpr Size RemainedSize = ObjectSize % SgprSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a955279bc849c7958c61b789c58d2445c5fd8894..b125e3adf63a5b50a63a3f8f62c2417eacc8b2dd 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -4,8 +4,8 @@ #pragma once namespace ck { -// Define the common macro for gfx94x models -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +// Define the common macro for MI300 models +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif @@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> } }; +template +struct intrin_mfma_f32_32x32x16f16; + +template <> +struct intrin_mfma_f32_32x32x16f16<32, 32> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32f16; + +template <> +struct intrin_mfma_f32_16x16x32f16<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + 
template struct intrin_mfma_f32_32x32x8f16; @@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> }; // bfp16 +template +struct intrin_mfma_f32_32x32x16bf16; + +template <> +struct intrin_mfma_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32bf16; + +template <> +struct intrin_mfma_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8bf16_1k; @@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x32i8; + +template <> +struct intrin_mfma_i32_32x32x32i8<32, 32> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_16x16x64i8; + +template <> +struct intrin_mfma_i32_16x16x64i8<16, 16> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = 
__builtin_amdgcn_mfma_i32_16x16x64_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_i32_32x32x16i8; @@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +template +struct intrin_mfma_f32_32x32x64f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6, +/// and f4 data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. +template <> +struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4; + +template <> +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? 
+ scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4; + +template <> +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x128f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4 +/// data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. 
+template <> +struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + template struct intrin_mfma_f32_32x32x16f8f8; diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 5366c56a9dfa7275ecca75d41daaf1a5cba6333d..2afad00d497f840af8221ef66ae8ec24de7e23ec 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP @@ -38,6 +38,8 @@ struct Array } __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } + __host__ __device__ constexpr TData* begin() { return &mData[0]; } + __host__ __device__ constexpr TData* end() { return &mData[NSize]; } }; // empty Array @@ -54,7 +56,7 @@ template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... 
xs) { using data_type = remove_cvref_t; - return Array{std::forward(x), std::forward(xs)...}; + return Array{ck::forward(x), ck::forward(xs)...}; } // make empty array diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 902195e2f89395ce17b50be4cb3a5f85700299e0..86dcb6c15765347e16d184a76fa22ddcaa53c2c0 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " - "%d, %d\n C MFMA inst: %d\n", + "%d, %d\n C MFMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d/ %d\n", A_Buffer_Load_Inst_Num, B_Buffer_Load_Inst_Num, A_LDS_Write_Inst_Num, B_LDS_Write_Inst_Num, A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, - C_MFMA_Inst_Num); + C_MFMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); } }; diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index 9c7b954565d386a8fdecd21052b102e750ab7102..bd0ca42ecddb736aa002151969bf83bf093691ff 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP @@ -326,14 +326,14 @@ template __host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) { return unpack2( - [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); + [&](auto&&... 
zs) { return make_array(ck::forward(zs)...); }, ax, ay); } template __host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) { return unpack2( - [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); + [&](auto&&... zs) { return make_tuple(ck::forward(zs)...); }, tx, ty); } template diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 39f532e0e94609402ce3a7685d0d0f0b5ce32e25..f90fcf67915e6b75ae31b3d10a38b8f8dac23164 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,17 +1,328 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" - +#ifdef CK_CODE_GEN_RTC +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using float_t = float; +#endif namespace ck { +#ifdef CK_CODE_GEN_RTC +using byte = unsigned char; +#else +using std::byte; +#endif + using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); -using f8_t = _BitInt(8); -using bf8_t = unsigned _BitInt(8); +using f4_t = unsigned _BitInt(4); +using f6_t = _BitInt(6); // e2m3 format +using bf6_t = unsigned _BitInt(6); // e3m2 format + +struct f4x2_pk_t +{ + using type = uint8_t; + type data; + f4x2_pk_t() : data{type{}} {} + f4x2_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline type unpack(Number) const + { + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 0) + return data & 0b00001111; + else + return (data >> 4); + } + + __host__ __device__ inline type pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } +}; + +struct f6x16_pk_t +{ + // store 16 elements of f6_t in 
an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + f6x16_pk_t() : data{type{}} {} + f6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = 
next_value; + } + }); + + return packed; + } +}; + +struct f6x32_pk_t +{ + // store 32 elements of f6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + f6x32_pk_t() : data{type{}} {} + f6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t 
next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x16_pk_t +{ + // store 16 elements of bf6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + bf6x16_pk_t() : data{type{}} {} + bf6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into 
the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x32_pk_t +{ + // store 32 elements of bf6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + bf6x32_pk_t() : data{type{}} {} + bf6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the 
current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +// custom data type - pack int4 data +struct pk_i4_t +{ + using type = int8_t; + type data; + __host__ __device__ constexpr pk_i4_t() : data{type{}} {} + __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} +}; inline constexpr auto next_pow2(uint32_t x) { @@ -19,14 +330,16 @@ inline constexpr auto next_pow2(uint32_t x) return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; } -// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool, f4_t, f6_t, bf6_t template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value; + is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value; } // vector_type @@ -166,16 +479,37 @@ struct scalar_type #endif template <> -struct scalar_type +struct scalar_type +{ + using type = pk_i4_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f8_fnuz_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_fnuz_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type { - using type = f8_t; + using type = f8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; template <> -struct scalar_type +struct 
scalar_type { - using type = bf8_t; + using type = bf8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; @@ -187,7 +521,7 @@ struct scalar_type }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -223,7 +557,7 @@ struct vector_type()>> __device__ int static err = 0; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -283,20 +617,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d3_t __attribute__((ext_vector_type(3))); - using type = d4_t; + using type = d3_t; union { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; + d3_t d3_; + StaticallyIndexedArray d1x3_; + StaticallyIndexedArray d2x1_; + StaticallyIndexedArray d3x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -306,20 +640,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x4_; + return data_.d1x3_; } else if constexpr(is_same::value) { - return data_.d2x2_; + return data_.d2x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d4x1_; + return data_.d3x1_; } else { @@ -330,20 +664,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - 
return data_.d1x4_; + return data_.d1x3_; } else if constexpr(is_same::value) { - return data_.d2x2_; + return data_.d2x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d4x1_; + return data_.d3x1_; } else { @@ -353,22 +687,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - using type = d8_t; + using type = d4_t; union { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -378,25 +710,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x8_; + return data_.d1x4_; } else if constexpr(is_same::value) { - return data_.d2x4_; + return data_.d2x2_; } else if constexpr(is_same::value) { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; + return data_.d4x1_; } else { @@ -407,25 +734,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x8_; + return data_.d1x4_; } else if constexpr(is_same::value) { - return data_.d2x4_; + 
return data_.d2x2_; } else if constexpr(is_same::value) { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; + return data_.d4x1_; } else { @@ -435,24 +757,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d5_t __attribute__((ext_vector_type(5))); - using type = d16_t; + using type = d5_t; union { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; + d5_t d5_; + StaticallyIndexedArray d1x5_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d5x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -462,30 +780,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; + return data_.d1x5_; } else if constexpr(is_same::value) { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d16x1_; + return data_.d5x1_; } else { @@ -496,30 +804,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || 
is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; + return data_.d1x5_; } else if constexpr(is_same::value) { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d16x1_; + return data_.d5x1_; } else { @@ -529,26 +827,22 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d7_t __attribute__((ext_vector_type(7))); - using type = d32_t; + using type = d7_t; union { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; + d7_t d7_; + StaticallyIndexedArray d1x7_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d7x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -559,33 +853,24 @@ struct vector_type()>> __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x32_; + return data_.d1x7_; } else if constexpr(is_same::value) { - return data_.d2x16_; + return data_.d2x3_; } else if constexpr(is_same::value) { - return data_.d4x8_; - } - else if 
constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d32x1_; + return data_.d7x1_; } else { @@ -597,33 +882,24 @@ struct vector_type()>> __host__ __device__ constexpr auto& AsType() { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x32_; + return data_.d1x7_; } else if constexpr(is_same::value) { - return data_.d2x16_; + return data_.d2x3_; } else if constexpr(is_same::value) { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d32x1_; + return data_.d7x1_; } else { @@ -633,28 +909,22 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - using type = d64_t; + using type = d8_t; union { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; } data_; __host__ __device__ 
constexpr vector_type() : data_{type{0}} {} @@ -665,81 +935,135 @@ struct vector_type()>> __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x64_; + return data_.d1x8_; } else if constexpr(is_same::value) { - return data_.d2x32_; + return data_.d2x4_; } else if constexpr(is_same::value) { - return data_.d4x16_; + return data_.d4x2_; } else if constexpr(is_same::value) { - return data_.d8x8_; + return data_.d8x1_; } - else if constexpr(is_same::value) + else { - return data_.d16x4_; + return err; } - else if constexpr(is_same::value) + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) { - return data_.d32x2_; + return data_.d1x8_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d64x1_; + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; } else { return err; } } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d13_t __attribute__((ext_vector_type(13))); + + using type = d13_t; + + union + { + d13_t d13_; + StaticallyIndexedArray d1x13_; + StaticallyIndexedArray d4x3_; + StaticallyIndexedArray d8x1_; + StaticallyIndexedArray d13x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ 
constexpr auto& AsType() + __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; + return data_.d1x13_; } else if constexpr(is_same::value) { - return data_.d4x16_; + return data_.d4x3_; } else if constexpr(is_same::value) { - return data_.d8x8_; + return data_.d8x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d16x4_; + return data_.d13x1_; } - else if constexpr(is_same::value) + else { - return data_.d32x2_; + return err; } - else if constexpr(is_same::value) + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) { - return data_.d64x1_; + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; } else { @@ -749,30 +1073,24 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d8_t __attribute__((ext_vector_type(8))); typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - using type = d128_t; + using type = d16_t; union { - 
d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -784,41 +1102,28 @@ struct vector_type()>> { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x128_; + return data_.d1x16_; } else if constexpr(is_same::value) { - return data_.d2x64_; + return data_.d2x8_; } else if constexpr(is_same::value) { - return data_.d4x32_; + return data_.d4x4_; } else if constexpr(is_same::value) { - return data_.d8x16_; + return data_.d8x2_; } else if constexpr(is_same::value) { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; + return data_.d16x1_; } else { @@ -831,41 +1136,28 @@ struct vector_type()>> { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x128_; + return data_.d1x16_; } else if constexpr(is_same::value) { - return data_.d2x64_; + return data_.d2x8_; } else if constexpr(is_same::value) { - return data_.d4x32_; + return data_.d4x4_; } else if constexpr(is_same::value) 
{ - return data_.d8x16_; + return data_.d8x2_; } else if constexpr(is_same::value) { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; + return data_.d16x1_; } else { @@ -875,7 +1167,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -883,24 +1175,18 @@ struct vector_type()>> typedef T d8_t __attribute__((ext_vector_type(8))); typedef T d16_t __attribute__((ext_vector_type(16))); typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - using type = d256_t; + using type = d32_t; union { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -910,47 +1196,34 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + 
is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x256_; + return data_.d1x32_; } else if constexpr(is_same::value) { - return data_.d2x128_; + return data_.d2x16_; } else if constexpr(is_same::value) { - return data_.d4x64_; + return data_.d4x8_; } else if constexpr(is_same::value) { - return data_.d8x32_; + return data_.d8x4_; } else if constexpr(is_same::value) { - return data_.d16x16_; + return data_.d16x2_; } else if constexpr(is_same::value) { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; + return data_.d32x1_; } else { @@ -961,47 +1234,34 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x256_; + return data_.d1x32_; } else if constexpr(is_same::value) { - return data_.d2x128_; + return data_.d2x16_; } else if constexpr(is_same::value) { - return data_.d4x64_; + return data_.d4x8_; } else if constexpr(is_same::value) { - return data_.d8x32_; + return data_.d8x4_; } else if constexpr(is_same::value) { - return data_.d16x16_; + return data_.d16x2_; } else if constexpr(is_same::value) { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return 
data_.d256x1_; + return data_.d32x1_; } else { @@ -1010,60 +1270,677 @@ struct vector_type()>> } }; -template -struct non_native_vector_base -{ - using type = non_native_vector_base; - - __host__ __device__ non_native_vector_base() = default; - __host__ __device__ non_native_vector_base(const type&) = default; - __host__ __device__ non_native_vector_base(type&&) = default; - __host__ __device__ ~non_native_vector_base() = default; - - T d[N]; -}; - -// non-native vector_type implementation template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; - using type = d1_t; - - union alignas(next_pow2(1 * sizeof(T))) - { + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return 
data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + 
is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T 
d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return 
data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } +}; + +template +struct non_native_vector_base; + +template +struct nnvb_data_t_selector +{ + using type = unsigned _BitInt(8 * sizeof(T)); +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> +{ + using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + 
{ + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * 
size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} + { + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } +}; + +template +struct scalar_type>; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using type = d1_nnv_t; + + union alignas(next_pow2(1 * sizeof(T))) + { d1_t d1_; StaticallyIndexedArray d1x1_; + d1_nnv_t d1_nnv_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{}} {} + __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} __host__ __device__ 
constexpr vector_type(type v) : data_{v} {} template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value, + static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - return data_.d1x1_; + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } } template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value, + static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - return data_.d1x1_; + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } } }; template -struct vector_type()>> +struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; using type = d2_t; @@ -1081,10 +1958,11 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x2_; } @@ -1101,10 +1979,11 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x2_; } @@ -1120,11 +1999,12 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + 
using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; using type = d4_t; @@ -1143,10 +2023,11 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x4_; } @@ -1167,10 +2048,11 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x4_; } @@ -1190,12 +2072,13 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; using type = d8_t; @@ -1215,11 +2098,12 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x8_; } @@ -1244,11 +2128,12 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || 
is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x8_; } @@ -1272,13 +2157,14 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; using type = d16_t; @@ -1299,12 +2185,12 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x16_; } @@ -1333,12 +2219,12 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x16_; } @@ -1366,7 +2252,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = 
non_native_vector_base; @@ -1470,7 +2356,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1541,134 +2427,415 @@ struct vector_type()>> } } - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +using int64_t = long; + +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using 
bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; + +// bf8 +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using 
bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + +// pack int4 +using pk_i4x2_t = typename vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + +#ifdef CK_CODE_GEN_RTC +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() 
noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + + __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + + __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } + + __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } + 
+ __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned int binary_min = 0x00800000; + static constexpr unsigned int binary_max = 0x7F7FFFFF; + static constexpr unsigned int binary_lowest = 0xFF7FFFFF; + static constexpr unsigned int binary_qnan = 0xFFC00001; + static constexpr unsigned int binary_inf = 0x7F8000000; + + __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } + + __host__ __device__ static constexpr float Infinity() { return bit_cast(binary_inf); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative 
zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } }; -using int64_t = long; +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t 
binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= -// fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } -// bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + 
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } -// f8 -using f8x2_t = typename vector_type::type; -using f8x4_t = typename vector_type::type; -using f8x8_t = typename vector_type::type; -using f8x16_t = typename vector_type::type; -using f8x32_t = typename vector_type::type; -using f8x64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } -// bf8 -using bf8x2_t = typename vector_type::type; -using bf8x4_t = typename vector_type::type; -using bf8x8_t = typename vector_type::type; -using bf8x16_t = typename vector_type::type; -using bf8x32_t = typename vector_type::type; -using bf8x64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } -// u8 -using uint8x2_t = typename vector_type::type; -using uint8x4_t = typename vector_type::type; -using uint8x8_t = typename vector_type::type; -using uint8x16_t = typename vector_type::type; -using uint8x32_t = typename vector_type::type; -using uint8x64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + __host__ __device__ static 
constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; +#else template struct NumericLimits { __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } - __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } - __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } - __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits::quiet_NaN(); } - __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } }; @@ -1702,7 +2869,7 @@ struct NumericLimits #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> -struct NumericLimits +struct NumericLimits { // negative zero nan mode with exp bias = 8 static constexpr uint8_t binary_min = 0x08; // 0b00001000 @@ -1715,17 +2882,17 @@ struct NumericLimits // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } }; template <> -struct NumericLimits +struct NumericLimits { // negative zero nan mode with exp 
bias = 16 static constexpr uint8_t binary_min = 0x04; // 0b00000100 @@ -1738,13 +2905,172 @@ struct NumericLimits // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } - __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } + __host__ 
__device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; +#endif + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal 
& 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; - __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; - __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static 
constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } }; template @@ -1766,6 +3092,7 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; using bitwise_type = uint32_t; }; @@ -1783,33 +3110,158 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; using bitwise_type = uint16_t; }; template <> -struct NumericUtils +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; + +template <> +struct 
NumericUtils { static constexpr int exp = 4; static constexpr int mant = 3; static constexpr int bias = 8; // negative zero nan mode // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; }; template <> -struct NumericUtils +struct NumericUtils { static constexpr int exp = 5; static constexpr int mant = 2; static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; +}; +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 7; }; template <> -struct NumericUtils +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 15; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + 
static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils { static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int 
unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; }; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 03c4e16dd6e8afe9e954d1e67d55e21ddf3dbbe1..2b247cc02a001c4bb1797a3ef5a4386eaec3fc98 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP +#include "type.hpp" namespace ck { namespace debug { diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 0dcc514a2f6548d6ca4ac5f8d8c89ee09775131c..6de17a61522a226b25c9430a672a2f06532ff60f 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -29,6 +29,13 @@ struct DynamicBuffer ElementSpaceSize element_space_size_; T invalid_element_value_ = T{0}; + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) : p_data_{p_data}, element_space_size_{element_space_size} { @@ -54,7 +61,8 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { @@ -81,14 +89,18 @@ struct DynamicBuffer return amd_buffer_load_invalid_element_return_zero, t_per_x, coherence>( - p_data_, i, is_valid_element, element_space_size_); + p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else { return amd_buffer_load_invalid_element_return_customized_value, 
t_per_x, coherence>( - p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + p_data_, + i, + is_valid_element, + element_space_size_ / PackedSize, + invalid_element_value_); } } else @@ -190,12 +202,13 @@ struct DynamicBuffer dst_buf.p_data_, dst_offset, is_valid_element, - element_space_size_); + element_space_size_ / PackedSize); } template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { @@ -224,7 +237,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_store, t_per_x, coherence>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && is_same>::type, int8_t>::value && @@ -376,7 +389,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_atomic_add, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else { @@ -415,7 +428,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_atomic_max, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if(is_valid_element) { diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a692f533f8b9d891d6cef99ff879e54b75337103 --- /dev/null +++ b/include/ck/utility/e8m0.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/type.hpp" + +namespace ck { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. + * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ +struct e8m0_bexp_t +{ + using type = uint8_t; + type data; + + constexpr static type bias = 127; + constexpr static type nan_mask = 0xFF; + + __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {} + __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {} + __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast(init & nan_mask)} + { + } + __host__ __device__ explicit constexpr e8m0_bexp_t(float scale) + : data{static_cast((bit_cast(scale) & (nan_mask << 23)) >> 23)} + { + } + + __host__ __device__ explicit constexpr operator float() const + { + if(data == nan_mask || data == 0) + { + uint32_t bits = data << 1; + bits |= 1; + bits <<= 22; + return bit_cast(bits); + } + else + { + uint32_t bits = data << 23; + return bit_cast(bits); + } + } + + __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const + { + // strict IEEE compliance for NaN + return data == other.data && data != nan_mask; + } + + __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } +}; + +namespace utils { + +template +__host__ __device__ inline int get_exponent_value(T x); + +template <> +__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +{ + return x.data; +} + +} // namespace utils + +} // namespace ck diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index c0a3c99f1fdafea9f151fe9fc319c2f7aaa0ffda..6ba63fc761c300c24e6013dce499f5d1c2ba9f27 
100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,14 +1,31 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once namespace ck { +#ifndef CK_CODE_GEN_RTC template using enable_if = std::enable_if; template using enable_if_t = typename std::enable_if::type; +#else +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; +#endif + } // namespace ck diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 6455402dcb331d91240ebe09d4d553f4d355f96e..809f302f743b1d9152afd98952010fddca92386a 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#ifndef CK_CODE_GEN_RTC #pragma once #include @@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck +#endif diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 91797d24092e3e32ad4a6bd40958952b124d9978..cd48ed17474480007f63180a7a25383172a3c8bd 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y) { if constexpr(predicate) { - return std::forward(x); + return ck::forward(x); } else { - return std::forward(y); + return ck::forward(y); } } diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index b5f3df8d7c517dfaf01320e41721da174883c2d9..8e86a296dc2ea0e1aca99a8480d5a826583ffd30 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP @@ -21,7 +21,7 @@ struct unpack_impl> template __host__ __device__ constexpr auto operator()(F&& f, X&& x) const { - return std::forward(f)(std::forward(x).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})...); } }; @@ -35,8 +35,8 @@ struct unpack2_impl, Sequence> template __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const { - return std::forward(f)(std::forward(x).At(Number{})..., - std::forward(y).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})..., + ck::forward(y).At(Number{})...); } }; @@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x) { using X_ = remove_reference_t; return detail::unpack_impl::type>{}( - std::forward(f), std::forward(x)); + ck::forward(f), ck::forward(x)); } // TODO: properly implement unpack that takes any number of containers @@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) using Y_ = remove_reference_t; return detail::unpack2_impl::type, typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( - std::forward(f), std::forward(x), std::forward(y)); + ck::forward(f), ck::forward(x), ck::forward(y)); } } // namespace ck diff --git a/include/ck/utility/integral_constant.hpp 
b/include/ck/utility/integral_constant.hpp index 376070eb3d8ac326603b71e52e76949c168f4219..75f35d762c52d465b39e718598082b0360905a5b 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ return integral_constant{}; } +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; } // namespace ck diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index 7a324a6c458b3f1b8bb8037ccfac76e5eadceee0..a700fcfff1dd21de1a5784e0c74132f7812ab7fa 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -1,22 +1,24 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/integral_constant.hpp" + namespace ck { namespace detail { template class Op, class... Args> struct detector { - using value_t = std::false_type; + using value_t = integral_constant; using type = Default; }; template class Op, class... Args> -struct detector>, Op, Args...> +struct detector>, Op, Args...> { - using value_t = std::true_type; + using value_t = integral_constant; using type = Op; }; } // namespace detail @@ -32,12 +34,12 @@ template