diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..276690bd4f96f50eb6a5121d237c8c9bcb80224d --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/docs/sphinx" # Location of package manifests + open-pull-requests-limit: 10 + schedule: + interval: "daily" diff --git a/.gitignore b/.gitignore index 5667695bb5500d44aac0a9aa3ad24e0f9b458c32..7af066c82dcef3cf4cc846168567e90e0a0c9db3 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,11 @@ build* .gdb_history install.dir* -# directories containing generated documentation -docs/source/_build/ -docs/docBin/ +# documentation artifacts +_build/ +_images/ +_static/ +_templates/ +_toc.yml +docBin/ +_doxygen/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f50df2525d2948bbabe36184f23310d47339d8e --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,18 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.8" + +sphinx: + configuration: docs/conf.py + +formats: [htmlzip, pdf, epub] + +python: + install: + - requirements: docs/sphinx/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 01883500465583097aa68642e89a89a06b6750b2..3898b5ce2d9c1a4006209a1aa9fa6e48302ea30a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ Full documentation for Composable Kernel is not yet available. - Added multi-embeddings support (#542). - Added Navi3x blockwise GEMM and real GEMM support (#541). - Added Navi grouped ConvBwdWeight support (#505). +- Added pool3d forward (#697). +- Added maxpool backward (#750). ### Changed - Changed ... diff --git a/CMakeLists.txt b/CMakeLists.txt index f861e302039de8e2c1ae17bf7cb39361aff8535f..c9fb6b4552c81ef29e77dcde807f552b42916572 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ include(TargetFlags) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) +option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -29,6 +30,12 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() +if(USE_OPT_NAVI3X) + add_compile_options(-mcumode) + add_compile_options(-mno-wavefrontsize64) + message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8ccfe99c3cc73b643f8b92cb654005e54c0774bd..07d83688176bd771142a8698fda492c538c691b1 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -4,7 +4,7 @@ This is the list of developers and contributors to Composable Kernel library ## Developers -[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2022 +[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023 [Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022 diff --git a/Dockerfile b/Dockerfile index b03cb836ad60c876fe4ce70078d53ae6266f8708..bc477680604a41faf525a8ca191508161ff96c21 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,30 +1,41 @@ FROM ubuntu:20.04 - -ARG ROCMVERSION=5.3 -ARG compiler_version="release" +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=5.6 +ARG compiler_version="" ARG compiler_commit="" RUN set -xe ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins -RUN useradd -rm -d /home/manitera -s /bin/bash -u 1002 manitera # Add rocm repository +RUN chmod 1777 /tmp RUN apt-get update -RUN apt-get install -y wget gnupg -RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - -RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" +RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl +RUN if [ "$ROCMVERSION" != "5.6" ]; then \ + wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \ + elif [ "$ROCMVERSION" = "5.6" ] && [ "$compiler_version" = "" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \ + apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \ + amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \ + amdgpu-install -y --usecase=rocm --no-dkms; \ + elif [ "$ROCMVERSION" = "5.6" ] && [ "$compiler_version" = "rc3" ] || [ "$compiler_version" = "amd-stg-open" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.6-20.04-1_all.deb" && \ + apt update && apt-get install -y ./amdgpu-install-internal_5.6-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.6 rel-45 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1602498 && amdgpu-install -y --usecase=rocm --no-dkms; \ + fi + RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" +RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - apt-utils \ build-essential \ ccache \ - cmake-data \ cmake \ - curl \ git \ hip-rocclr \ jq \ @@ -34,17 +45,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libpthread-stubs0-dev \ llvm-amdgpu \ pkg-config \ - python \ python3 \ - python-dev \ python3-dev \ python3-pip \ sshpass \ software-properties-common \ - rocm-dev \ - rocm-device-libs \ - rocm-cmake \ vim \ + nano \ zlib1g-dev \ openssh-server \ clang-format-10 \ @@ -52,6 +59,17 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +#Install latest version of cmake +RUN apt purge --auto-remove -y cmake +RUN apt update +RUN apt install -y software-properties-common lsb-release +RUN apt clean all +RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null +RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" +RUN apt install -y kitware-archive-keyring +RUN rm /etc/apt/trusted.gpg.d/kitware.gpg +RUN apt install -y cmake + # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -87,12 +105,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ]; then \ - sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/hip/bin/hipcc.pl && \ - sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/bin/hipcc.pl; \ - fi - -RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ @@ -100,7 +113,7 @@ RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_com else echo "using the release compiler"; \ fi -RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index bb0b352d757d10af7b678536ad6d01f00300a797..fbff349fc3aaad004d81483fef6952d975660fdd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -19,12 +19,33 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.COMPILER_COMMIT == ""){ - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + if (params.ROCMVERSION != "5.6"){ + if (params.COMPILER_VERSION == "") { + img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" + } + else{ + if (params.COMPILER_COMMIT == ""){ + img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + } + else{ + def commit = "${params.COMPILER_COMMIT}"[0..6] + img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + } + } } else{ - def commit = "${params.COMPILER_COMMIT}"[0..6] - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + if (params.COMPILER_VERSION == "") { + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}" + } + else{ + if (params.COMPILER_COMMIT == ""){ + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + } + else{ + def commit = "${params.COMPILER_COMMIT}"[0..6] + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + } + } } return img } @@ -49,11 +70,11 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "release"){ - compiler = "/opt/rocm/llvm/bin/clang++" + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + compiler = "/llvm-project/build/bin/clang++" } else{ - compiler = "/llvm-project/build/bin/clang++" + compiler = "/opt/rocm/llvm/bin/clang++" } } return compiler @@ -232,7 +253,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "release"){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -287,7 +308,7 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "release"){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -420,7 +441,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "release"){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -472,10 +493,11 @@ def Build_CK(Map conf=[:]){ { cmake_build(conf) dir("build"){ + //run tests and examples + sh 'make -j\$(( \$(nproc) / 2 )) check' if (navi_node == 0 ){ - //run tests and examples on all nodes except Navi - sh 'make -j check' - //we only need the ckProfiler to run the performance tests, so we pack and stash it + //we only need the ckProfiler to run the performance tests, so we pack and stash it + //do not stash profiler on Navi nodes sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' stash "ckProfiler.tar.gz" } @@ -576,7 +598,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true - 0 21 * * * % COMPILER_VERSION=release;COMPILER_COMMIT= + 0 21 * * * % ROCMVERSION=5.5;COMPILER_VERSION=release;COMPILER_COMMIT= 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" pipeline { @@ -594,16 +616,16 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.4.3', - description: 'Specify which ROCM version to use: 5.4.3 (default).') + defaultValue: '5.6', + description: 'Specify which ROCM version to use: 5.6 (default).') string( name: 'COMPILER_VERSION', - defaultValue: 'amd-stg-open', - description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).') + defaultValue: '', + description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', - defaultValue: '5541927df00eabd6a110180170eca7785d436ee3', - description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') + defaultValue: '', + description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') string( name: 'BUILD_COMPILER', defaultValue: 'hipcc', @@ -665,12 +687,31 @@ pipeline { { parallel { + stage("Build CK and run Tests on MI100/MI200/MI300") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("gfx908 || gfx90a") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } stage("Build CK and run Tests on MI100/MI200") { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } agent{ label rocmnode("gfx908 || gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -684,8 +725,8 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030;gfx1100;gfx1101;gfx1102" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ diff --git a/LICENSE b/LICENSE index 2fe9a8455efaeda2eab474b2aa038ec2d9e76841..e03fddaf78080705d26ec277629cfb8010c077bf 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 151da974a559bef2e52971034f607d4afe44f632..a45f61a37df2b82805a9a0643abc97821803ad5e 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,60 @@ # Composable Kernel ## Methodology + Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model * Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". -![ALT](/doc/image/ck_component.png "CK Components") +![ALT](/docs/data/ck_component.png "CK Components") ## Code Structure + Current CK library are structured into 4 layers: * "Templated Tile Operators" layer * "Templated Kernel and Invoker" layer * "Instantiated Kernel and Invoker" layer * "Client API" layer -![ALT](/doc/image/ck_layer.png "CK Layers") +![ALT](/docs/data/ck_layer.png "CK Layers") + +## Documentation + +Run the steps below to build documentation locally. + +``` +cd docs +pip3 install -r sphinx/requirements.txt +python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html +``` ## Contributors + The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) ## Citation + If you use CK, please use following citations: * CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) * [CITATION.cff](/CITATION.cff) ## License + CK is released under the MIT license. [License File](/LICENSE) # Build CK ## Build docker image + ```bash DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . ``` ## Launch docker + ```bash docker run \ -it \ @@ -50,10 +67,12 @@ ck:latest \ ``` ## Build CK + ```bash mkdir build && cd build # Need to specify target ID, example below is for gfx908 and gfx90a + cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ @@ -64,6 +83,7 @@ cmake ``` ### Build examples and tests + ```bash make -j examples tests make test @@ -73,21 +93,25 @@ Instructions for running each individual examples are under [example](/example) ## Build ckProfiler + ```bash make -j ckProfiler ``` Instructions for running ckProfiler are under [profiler](/profiler) ## Install CK + ```bash make install ``` ## Using CK as pre-built kernel library + Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) ## Caveat ### Kernel Timing and Verification + CK's own kernel timer will warn up kernel once, and then run it multiple times to get average kernel time. For some kernels that use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index ba7118ba3929e3e3bcdf02e40044748860bfeebe..c37f208db1cee9bfff1fff469fd79d059fb179f0 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index b7b724ccc484d6c9ea83e791aa5b8fcde420eef7..ba2952022233b3ffd21fc245b8fd199a5ec5b096 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -11,3 +11,17 @@ target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_ope add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) + +add_custom_target(client_gemm_fastgelu_generic_examples) + +add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) +target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 08f297f58a8d522aec8c991a8d38484ccf8e420c..756889562e84c66efb5f972621bcb61edda3af82 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ed942f0adf64df90ea1046d1de191ce2247d7c4 --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, D1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 658c1e9e8fcbeab1ec7be4115d5a7d62c5e77283..8d2a8c234aae63a3566478b5aa9588389a247d4e 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); StrideD0 = std::stoi(argv[6]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[7]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..644b428fc9f51b28a0f81ed7d79ac741ca6fbcbe --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideE = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index ea269545a5cc576f7cf98ea8496acf713b6aaea6..c02df018fd35c6d37a2c6f9b9fded6390c9afb19 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -72,7 +72,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[6]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..482e93b421f7700cdf37ee014cf407ca3c63555b --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = FastGelu; + +using ADataType = F16; +using BDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::FastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index caa6573788d201b9fda605e5388423d5966d41ec..58c91f903bc7e2f3f296b06c746b1629513bc7f2 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -172,18 +172,19 @@ int main() BLayout, CLayout>(); - const auto normalize_ptrs = - ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances< - CDataType, - ReduceDataType, - ReduceDataType, - GammaDataType, - BetaDataType, - LayerNormOutDataType>(); - std::cout << "found " << gemm_reduce_ptrs.size() << " gemm_reduceMean_reduceSquareMean instances" << std::endl; + using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, + ck::Tuple, + ck::tensor_operation::element_wise::Normalize, + 2>; + + const auto normalize_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + NormalizeDeviceOp>::GetInstances(); + std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; auto f_matrix_space_size = diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index d4f0c2048ba81c3a1f98a22c5eb57654b0498806..3d5fb6004844af269f7786ae05de8c20cc720633 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index 971d5d9f1c00e471adb4e9a60c0331e602d61468..7ffedfeef36839c1c8bb942ffae11e5f6dcadf22 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,8 +1,14 @@ -add_executable(client_contraction_scale contraction_scale.cpp) -target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations) +add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) +target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations) -add_executable(client_contraction_bilinear contraction_bilinear.cpp) -target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) +add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) +target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) +target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) +target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations) add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations) diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp similarity index 99% rename from client_example/04_contraction/contraction_bilinear.cpp rename to client_example/04_contraction/contraction_bilinear_fp32.cpp index 91dead41a4cac19db857b99a233839e9e6647c57..89f834b9824e134f8f0aeed8aa54f78a5c8824a3 100644 --- a/client_example/04_contraction/contraction_bilinear.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1aa3ba7de597a0b97a295b0f7ee7ad21b1e9cd80 --- /dev/null +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" +#include "ck/library/utility/numeric.hpp" + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Bilinear; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ +// kknn +#if 1 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// knnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mknn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mnnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +#endif + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 25) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])}; + + alpha = std::stof(argv[23]); + beta = std::stof(argv[24]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg23 to 24: alpha, beta\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem d_device_buf(sizeof(DDataType) * + f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp index 62be3377a2fb18bf388b857ecb758c0b7987871c..f8ea2258c2ba262e9db42f1b4dc92ff16cdc6286 100644 --- a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp +++ b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp similarity index 99% rename from client_example/04_contraction/contraction_scale.cpp rename to client_example/04_contraction/contraction_scale_fp32.cpp index 4e08ee19cdb098b2dfb70a662d59c87008400123..ba7b0633c33aabeeb06547edfb08f506e637e599 100644 --- a/client_example/04_contraction/contraction_scale.cpp +++ b/client_example/04_contraction/contraction_scale_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24e52eb5aa423339ff96ad0914dc479d715fe7b7 --- /dev/null +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp" +#include "ck/library/utility/numeric.hpp" + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Scale; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DsDataType = ck::Tuple<>; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ +// kkn +#if 1 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// knn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mkn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +#endif + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 20) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + scale = std::stof(argv[19]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg19: scale\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{scale}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index 856a4cc21935f094bfd7040ea095fb04d78b7eb4..4af4d7abe8ee6c5120f6060c78141021052cb619 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index e939ce8dfedb10166b15388cae1d659d33e2154a..2ccad27a88757b576a050d0558acc4108857cbd1 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -53,12 +53,35 @@ int main(int argc, char* argv[]) SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements); - using DeviceOp = ck::tensor_operation::device:: - DeviceSoftmax; + using DeviceOp = ck::tensor_operation::device::DeviceSoftmax; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); + auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + alpha, + beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::string best_op_name; @@ -74,11 +97,6 @@ int main(int argc, char* argv[]) { auto& op_ptr = op_ptrs[i]; - if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim) - { - continue; - } - auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 9fbdb83b1cf10082958730e5115f3bf54e12d415..70be0101c6d92512ab28a72797d2b8c46fb55281 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index ece6e30c56004f58ff136502b4d8874e8269cd3a..57a210fa1f5f0d3f0eca014e08ccd119d29e4fd6 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -17,22 +17,22 @@ using InDataType = ck::half_t; using WeiDataType = ck::half_t; using OutDataType = ck::half_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 32; -static constexpr ck::index_t N = 256; -static constexpr ck::index_t K = 192; -static constexpr ck::index_t C = 192; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; -static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; +static constexpr ck::index_t N = 256; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W struct SimpleDeviceMem { @@ -52,50 +52,24 @@ struct SimpleDeviceMem int main() { - std::array in_lengths{G, N, Hi, Wi, C}; - std::array in_strides{0, 0, 0, 0, 1}; - - std::array wei_lengths{G, K, Y, X, C}; - std::array wei_strides{0, 0, 0, 0, 1}; - - std::array out_lengths{G, N, Ho, Wo, K}; - std::array out_strides{0, 0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNHWC/GKYXC/GNHWK to GNCHW/GKCYX/GNCHW - std::rotate( - rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 3)); - std::rotate( - rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 3)); - std::rotate( - rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 3)); - std::rotate( - rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 3)); - std::rotate( - rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 3)); - std::rotate( - rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 3)); + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; std::array filter_strides{1, 1}; std::array filter_dilations{1, 1}; std::array input_left_pads{1, 1}; std::array input_right_pads{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDRun(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * G * N * Ho * Wo * K; + sizeof(OutDataType) * N * Ho * Wo * G * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; diff --git a/client_example/08_fused_attention/fused_attention.cpp b/client_example/08_fused_attention/fused_attention.cpp index fe927da1248786a4b943f610ce38b75f0d88defd..df6bc11a70d32df221612466a8af0fbcd9cafb1c 100644 --- a/client_example/08_fused_attention/fused_attention.cpp +++ b/client_example/08_fused_attention/fused_attention.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp index 3113b7856025af74c610981db130a8d965c36a24..6c9f3bc8f6f5a3c06f339f1246b5b3985e11d2d8 100644 --- a/client_example/08_fused_attention/fused_attention_bias.cpp +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index a4dd80cd3fdd45c2987e997d17ed80a699a371fe..2b7d6fc806ad48f650c3e1f72ef3190cb9f342f4 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,6 +1,12 @@ +add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) +add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index cf6807f0ddfe212bfa47adb54a64ee4b8bf663a4..cd504e942e943b8f174442554eca379cd908ba6e 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -17,26 +17,26 @@ using BiasDataType = int32_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -55,8 +55,11 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; @@ -64,17 +67,18 @@ int main(int argc, char* argv[]) std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< NumDimSpatial, @@ -203,4 +207,4 @@ int main(int argc, char* argv[]) } return 0; -} \ No newline at end of file +} diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index 7cbbd28322949e7a8d1bd85da891b5d8108e854d..f4aa3666b1c15eccd37bdef549f8402bcd1b252b 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,25 +16,26 @@ using WeiDataType = int8_t; using BiasDataType = int32_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float requant_scale = 0.5f; // requantize qAcc to qz struct SimpleDeviceMem { @@ -54,23 +55,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDMakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {bias.GetDeviceBuffer()}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {bias_lengths}, - {bias_strides}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); @@ -165,25 +171,26 @@ int main(int argc, char* argv[]) auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {bias.GetDeviceBuffer()}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {bias_lengths}, - {bias_strides}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ebdbbf52c0ca257dd8f672a48b32f80fa5bb8816 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::NHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::NHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float sz_inv = 0.5f; // inverse of scale_z + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(RequantScaleDataType) * K + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(best_op_id != -1) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d60baee06fddb27451f1f51642bad738739c4bb --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::NHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::NHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float sacc = 0.5f; // scale of acc +static constexpr float sz_inv = 0.5f; // inverse of scale_z + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sacc, sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(best_op_id != -1) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sacc, sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index c1c5a651eb350a4016923809d417ca6eb5811eab..dd81d9ee6b6dc0ee453cac61d7e795c79ba3c9fa 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,25 +16,25 @@ using WeiDataType = int8_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { @@ -54,23 +54,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD #include @@ -15,24 +15,25 @@ using InDataType = int8_t; using WeiDataType = int8_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float requant_scale = 0.5f; // requantize qAcc to qY struct SimpleDeviceMem { @@ -52,20 +53,24 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDMakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {}, - {}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); @@ -158,25 +164,26 @@ int main(int argc, char* argv[]) auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {}, - {}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/09_quantization/gemm_quantization.cpp b/client_example/09_quantization/gemm_quantization.cpp index 242504b44ff942dcefdbf6cf7ffdc7d6b5b2bbdb..b14e68fa082f8f9d05ab5f00471a2fa82d1d113b 100644 --- a/client_example/09_quantization/gemm_quantization.cpp +++ b/client_example/09_quantization/gemm_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp index 55c789804230ccccf66d68be9244c5c4111451e6..1b2e8abc201c2aed2cd2eebccb68405a25033a43 100644 --- a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index a906263333c8a15947e12fe8702e431d4acb7999..62eb7bcf553f2ddd17181ce6697dcda70ae9cad6 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index de68f46d398958917e49bf14178f66414590ed86..bc4a6fe0bfa9e118bbd6ba32ecc7dd68f3b8b2c3 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index 8ef21986a4d9a5f1eb25e21a1073c4cc341da88d..c0140f71c15993da2e0d4b629321a355c74f004a 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index 322667a46bacae8d0c681939c3890ef9ff476b0e..365373343668409f130f642d8e217cfd48a9afac 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp index 3117d162db71a0a4d02a15decac20e1f0f50d56e..5e6627ce14d113224c7b0acb4fea69cb36c1f369 100644 --- a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp index 9cfeee1cfe106e69f83dc0184f3956d6751a2947..d45782d8e0ff37027a204c6820447286581f1138 100644 --- a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp +++ b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index 28524a9eee9b87db4223484497f8e2181ed366ea..c74d7c6bd8cf9cec54168792a8718768827d7b33 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index 2275158bcb26d36d871c9e7086b45d4539584785..b45b72f0de0199daa88e9e42be320e2669398dfe 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a6bb5aa65be098716b504ae9d0543ed9ae129dc0..449c9466e829baa8658f17fe8e9ae86c5a297238 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 43c98f1e9b8ae8b89893d5e430d4f4ab89678803..7e8c98b6037418001571707a82aa14987059fb4f 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -11,7 +11,7 @@ using WeiDataType = float; using OutDataType = float; using InLayout = ck::tensor_layout::convolution::NDHWGC; -using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp index 223ed29be9ac66e8805d8d5b3e696bbac4f5c922..7ba3224fc3244ab3bd472b98864efa42ae132fde 100644 --- a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..17c88cb61bc5daaa3ac0c6a92459c147691be2fe --- /dev/null +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_groupnorm_swish groupnorm_swish.cpp) +target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations) diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1d198d2282b85f21fb4c04afde443133c801b15 --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_swish.cpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" + +using XDataType = ck::half_t; +using GammaDataType = float; +using BetaDataType = float; +using YDataType = ck::half_t; +using ComputeDataType = float; +using Swish = ck::tensor_operation::element_wise::Swish; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t xy_size = N * H * W * G * C; + std::size_t gamma_beta_size = G * C; + + std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalization; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = + generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + Swish{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + Swish{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + + sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + Swish{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/CMakeLists.txt b/client_example/19_pool_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..13f9f73c83d55c801fdf4609e13f6b0813cb0c67 --- /dev/null +++ b/client_example/19_pool_fwd/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) +target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) + +add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) +target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) \ No newline at end of file diff --git a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp b/client_example/19_pool_fwd/avg_pool3d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2edaf474b5644a0a91fb51e675e1c37286841967 --- /dev/null +++ b/client_example/19_pool_fwd/avg_pool3d_fwd.cpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; +#if 0 +constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool OutputIndex = false; +#else +constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +constexpr bool OutputIndex = false; +#endif + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(InDataType) + out_tensor_size * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_tensor_size * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/max_pool2d_fwd.cpp b/client_example/19_pool_fwd/max_pool2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c776dc12da391de21daa2cf464d2bde39873db9f --- /dev/null +++ b/client_example/19_pool_fwd/max_pool2d_fwd.cpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; + +constexpr ck::index_t InOutRank = 4; +constexpr ck::index_t WindowRank = 2; +#if 1 +constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool OutputIndex = true; +#else +constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +constexpr bool OutputIndex = false; +#endif + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + input_left_pads, + input_right_pads, + {2, 3}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(InDataType) + out_tensor_size * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_tensor_size * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + input_left_pads, + input_right_pads, + {2, 3}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 87bcb08e83057b4f162870706a420a2d88d468bf..369cd0b54c1fac60b061cdafdb0b02b819a8accd 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -92,6 +92,7 @@ else() -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default + -Wno-unsafe-buffer-usage ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 3c6cb56ccea3490c6d7a4eaf6107749dafa87bdf..d6577ac33e71c575b2aa441da30838f58157f0b4 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -21,6 +21,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS -Wno-comma -Wno-old-style-cast -Wno-deprecated + -Wno-unsafe-buffer-usage ) message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") diff --git a/doc/markdown/dockerhub.md b/doc/markdown/dockerhub.md deleted file mode 100644 index 91b6cb2295cf4a8dd08ead2eb165119c74996084..0000000000000000000000000000000000000000 --- a/doc/markdown/dockerhub.md +++ /dev/null @@ -1,93 +0,0 @@ -## CK docker hub - -[Docker hub](https://hub.docker.com/r/rocm/composable_kernel) - -## Why do I need this? - -To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images. - -## So what is Composable Kernel? - -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. - -To get the CK library - -``` -git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git -``` - -run a docker container - -``` -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ -/bin/bash -``` - -and build the CK - -``` -mkdir build && cd build - -# Need to specify target ID, example below is for gfx908 and gfx90a -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D GPU_TARGETS="gfx908;gfx90a" \ -.. -``` - -and - -``` -make -j examples tests -``` - -To run all the test cases including tests and examples run - -``` -make test -``` - -We can also run specific examples or tests like - -``` -./bin/example_gemm_xdl_fp16 -./bin/test_gemm_fp16 -``` - -For more details visit [CK github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel), [CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example), [even more CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example). - -## And what is inside? - -The docker images have everything you need for running CK including: - -* [ROCm](https://www.amd.com/en/graphics/servers-solutions-rocm) -* [CMake](https://cmake.org/) -* [Compiler](https://github.com/RadeonOpenCompute/llvm-project) - -## Which image is right for me? - -Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are: - -* "ck" - made for running Composable Kernel -* "ub20.04" - based on Ubuntu 20.04 -* "rocm5.4" - ROCm platform version 5.4 -* "release" - compiler version is release - -So just pick the right image for your project dependencies and you're all set. - -## DIY starts here - -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the [Dockerfile](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile) for your needs. - -## License - -CK is released under the MIT [license](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/LICENSE). diff --git a/doc/markdown/tutorial_hello_world.md b/doc/markdown/tutorial_hello_world.md deleted file mode 100644 index 297df10b5c65eaf2058c4fa436a4008c42cbf440..0000000000000000000000000000000000000000 --- a/doc/markdown/tutorial_hello_world.md +++ /dev/null @@ -1,191 +0,0 @@ -## CK Hello world - -## Motivation - -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever. - -During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project. - -## Description - -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more. - -So how do we (almost) reach the speed of light? CK acceleration abilities are based on: - -* Layered structure. -* Tile-based computation model. -* Tensor coordinate transformation. -* Hardware acceleration use. -* Support of low precision data types including fp16, bf16, int8 and int4. - -If you are excited and need more technical details and benchmarking results - read this awesome blog [post](https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224). - -For more details visit our [github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel). - -## Hardware targets - -CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture - -GPU Target AMD GPU -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT - -There are also [cloud options](https://aws.amazon.com/ec2/instance-types/g4/) you can find if you don't have an AMD GPU at hand. - -## Build the library - -First let's clone the library and rebase to the tested version: - -``` -git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git -cd composable_kernel/ -git checkout tutorial_hello_world -``` - -To make our lives easier we prepared [docker images](https://hub.docker.com/r/rocm/composable_kernel) with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version. - -If your current folder is ${HOME}, start the docker container with - -``` -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${HOME}:/root/workspace \ -rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ -/bin/bash -``` - -If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure. - -Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library - -``` -cd composable_kernel/ -``` - -Create and go to the "build" directory - -``` -mkdir build && cd build -``` - -In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag - -``` -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=OFF \ --D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -``` - -If everything went well the cmake run will end up with: - -``` --- Configuring done --- Generating done --- Build files have been written to: "/root/workspace/composable_kernel/build" -``` - -Finally, we can build examples and tests - -``` -make -j examples tests -``` - -If everything is smooth, you'll see - -``` -Scanning dependencies of target tests -[100%] Built target tests -``` - -## Run examples and tests - -Examples are listed as test cases as well, so we can run all examples and tests with - -``` -ctest -``` - -You can check the list of all tests by running - -``` -ctest -N -``` - -We can also run them separately, here is a separate example execution. - -``` -./bin/example_gemm_xdl_fp16 1 1 1 -``` - -The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. - -If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 -``` - -Meanwhile, running it on a gfx1030 device should result in - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -``` - -But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like - -``` -./bin/example_gemm_dl_fp16 1 1 1 -``` - -and it should result in something nice similar to - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} -b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> -``` - -Or we can run a separate test - -``` -ctest -R test_gemm_fp16 -``` - -If everything goes well you should see something like - -``` -Start 121: test_gemm_fp16 -1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - -100% tests passed, 0 tests failed out of 1 -``` - -## Summary - -In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task. - -P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure! diff --git a/docs/source/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst similarity index 98% rename from docs/source/API_Reference_Guide.rst rename to docs/API_Reference_Guide.rst index 3665049dd6f98237b2118512f920b670c4cd831c..b59c6e302692cd6b9d4035ed27d5901e790a0233 100644 --- a/docs/source/API_Reference_Guide.rst +++ b/docs/API_Reference_Guide.rst @@ -49,4 +49,4 @@ used in the CK GPU implementation of Flashattention. .. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic -.. bibliography:: \ No newline at end of file +.. bibliography:: diff --git a/docs/source/Contributors_Guide.rst b/docs/Contributors_Guide.rst similarity index 100% rename from docs/source/Contributors_Guide.rst rename to docs/Contributors_Guide.rst diff --git a/docs/source/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst similarity index 99% rename from docs/source/Supported_Primitives_Guide.rst rename to docs/Supported_Primitives_Guide.rst index 066e024bc0b8a174d645cc7c065f92d9778bff57..4c3adf67d7119e22a622373ab8d1dccb4d024e18 100644 --- a/docs/source/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -72,4 +72,4 @@ Else if :math:`j>1`, \tilde{Y}_{ij} &= \diag(z^{new}_{i})^{-1} \exp(\tilde{m}_{ij} - m^{new}_i ) \tilde{P}_{ij} \\ z_i &= z^{new}_i \\ m_i &= m^{new}_i \\ - \end{align} \ No newline at end of file + \end{align} diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..0de590da1a53daba8dc0e379e68abbae250e0667 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,36 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import subprocess + +from rocm_docs import ROCmDocs + + +name = "Composable Kernel" +get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' +version = subprocess.getoutput(get_version) +if len(version) > 0: + name = f"{name} {version}" + +external_toc_path = "./sphinx/_toc.yml" + +docs_core = ROCmDocs(f"{name} Documentation") +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") +docs_core.setup() + +mathjax3_config = { +'tex': { + 'macros': { + 'diag': '\\operatorname{diag}', + } + } +} + +for sphinx_var in ROCmDocs.SPHINX_VARS: + globals()[sphinx_var] = getattr(docs_core, sphinx_var) + +extensions += ['sphinxcontrib.bibtex'] +bibtex_bibfiles = ['refs.bib'] diff --git a/doc/image/ck_component.png b/docs/data/ck_component.png similarity index 100% rename from doc/image/ck_component.png rename to docs/data/ck_component.png diff --git a/doc/image/ck_layer.png b/docs/data/ck_layer.png similarity index 100% rename from doc/image/ck_layer.png rename to docs/data/ck_layer.png diff --git a/docs/source/dockerhub.rst b/docs/dockerhub.rst similarity index 100% rename from docs/source/dockerhub.rst rename to docs/dockerhub.rst diff --git a/docs/Doxyfile b/docs/doxygen/Doxyfile similarity index 99% rename from docs/Doxyfile rename to docs/doxygen/Doxyfile index ca354598b239c2f9119682a695dfdd5c65d90529..1084f94c81bbef88a3373d6e9e5d24b89e7b5353 100644 --- a/docs/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -51,7 +51,7 @@ PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = ./rocm.jpg +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -775,10 +775,10 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../include/ck/tensor_operation/gpu/grid \ - ../include/ck/tensor_operation/gpu/block \ - ../include/ck/tensor_operation/gpu/thread \ - ../library/include/ck/library/utility +INPUT = ../../include/ck/tensor_operation/gpu/grid \ + ../../include/ck/tensor_operation/gpu/block \ + ../../include/ck/tensor_operation/gpu/thread \ + ../../library/include/ck/library/utility # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..f4e66c1b51f4528fa47057116a8f457680b3d972 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,52 @@ +============================ +Composable Kernel User Guide +============================ + +------------ +Introduction +------------ + +This document contains instructions for installing, using, and contributing to Composable Kernel (CK). + +----------- +Methodology +----------- + +Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: + +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". + +.. image:: data/ck_component.png + :alt: CK Components + +-------------- +Code Structure +-------------- + +Current CK library are structured into 4 layers: + +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +.. image:: data/ck_layer.png + :alt: CK Layers + +Documentation Roadmap +^^^^^^^^^^^^^^^^^^^^^ +The following is a list of CK documents in the suggested reading order: + +.. toctree:: + :maxdepth: 5 + :caption: Contents: + :numbered: + + tutorial_hello_world + dockerhub + Supported_Primitives_Guide + API_Reference_Guide + Contributors_Guide diff --git a/docs/license.rst b/docs/license.rst new file mode 100644 index 0000000000000000000000000000000000000000..ddb544496e41b92dcaadb6ff816946e46ca16da8 --- /dev/null +++ b/docs/license.rst @@ -0,0 +1,6 @@ +======= +License +======= + +.. include:: ../LICENSE + :literal: diff --git a/docs/source/refs.bib b/docs/refs.bib similarity index 100% rename from docs/source/refs.bib rename to docs/refs.bib diff --git a/docs/run_doc.sh b/docs/run_doc.sh deleted file mode 100755 index 58b0936c67882575e58a0cf624f48e5df0aeb3be..0000000000000000000000000000000000000000 --- a/docs/run_doc.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -set -eu - -# Make this directory the PWD -cd "$(dirname "${BASH_SOURCE[0]}")" - -# Build doxygen info -bash run_doxygen.sh - -# Build sphinx docs -cd source -make clean -make -e SPHINXOPTS="-t html" html -make latexpdf diff --git a/docs/run_doxygen.sh b/docs/run_doxygen.sh deleted file mode 100755 index f66c038c1bb7eb22427ff4619c2fbed2f78c4733..0000000000000000000000000000000000000000 --- a/docs/run_doxygen.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -set -eu - -# Make this directory the PWD -cd "$(dirname "${BASH_SOURCE[0]}")" - -# Build the doxygen info -rm -rf docBin -doxygen Doxyfile diff --git a/docs/source/Disclaimer.rst b/docs/source/Disclaimer.rst deleted file mode 100644 index 5dcff748c87a9d9616b3ec657d5bb5d1a743c456..0000000000000000000000000000000000000000 --- a/docs/source/Disclaimer.rst +++ /dev/null @@ -1,13 +0,0 @@ -************ -Disclaimer -************ -------------------------------- -AMD's standard legal Disclaimer -------------------------------- - -The information presented in this document is for informational purposes only and may contain technical inaccuracies, omissions, and typographical errors. The information contained herein is subject to change and may be rendered inaccurate for many reasons, including but not limited to product and roadmap changes, component and motherboard version changes, new model and/or product releases, product differences between differing manufacturers, software changes, BIOS flashes, firmware upgrades, or the like. Any computer system has risks of security vulnerabilities that cannot be completely prevented or mitigated. AMD assumes no obligation to update or otherwise correct or revise this information. However, AMD reserves the right to revise this information and to make changes from time to time to the content hereof without obligation of AMD to notify any person of such revisions or changes. THIS INFORMATION IS PROVIDED 'AS IS." AMD MAKES NO REPRESENTATIONS OR WARRANTIES WITH RESPECT TO THE CONTENTS HEREOF AND ASSUMES NO RESPONSIBILITY FOR ANY INACCURACIES, ERRORS, OR OMISSIONS THAT MAY APPEAR IN THIS INFORMATION. AMD SPECIFICALLY DISCLAIMS ANY IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR ANY PARTICULAR PURPOSE. IN NO EVENT WILL AMD BE LIABLE TO ANY PERSON FOR ANY RELIANCE, DIRECT, INDIRECT, SPECIAL, OR OTHER CONSEQUENTIAL DAMAGES ARISING FROM THE USE OF ANY INFORMATION CONTAINED HEREIN, EVEN IF AMD IS EXPRESSLY ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. AMD, the AMD Arrow logo, Radeon, Ryzen, Epyc, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. Google(R) is a registered trademark of Google LLC. PCIe(R) is a registered trademark of PCI-SIG Corporation. Linux(R) is the registered trademark of Linus Torvalds in the U.S. and other countries. Ubuntu(R) and the Ubuntu logo are registered trademarks of Canonical Ltd. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. (C)2023 Advanced Micro Devices, Inc. All rights reserved. - ----------------------- -Third Party Disclaimer ----------------------- -Third-party content is licensed to you directly by the third party that owns the content and is not licensed to you by AMD. ALL LINKED THIRD-PARTY CONTENT IS PROVIDED "AS IS" WITHOUT A WARRANTY OF ANY KIND. USE OF SUCH THIRD-PARTY CONTENT IS DONE AT YOUR SOLE DISCRETION AND UNDER NO CIRCUMSTANCES WILL AMD BE LIABLE TO YOU FOR ANY THIRD-PARTY CONTENT. YOU ASSUME ALL RISK AND ARE SOLELY RESPONSIBLE FOR ANY DAMAGES THAT MAY ARISE FROM YOUR USE OF THIRD-PARTY CONTENT. diff --git a/docs/source/Linux_Install_Guide.rst b/docs/source/Linux_Install_Guide.rst deleted file mode 100644 index 0e16bb6a9863edd3872228fdf8683f524055c1e6..0000000000000000000000000000000000000000 --- a/docs/source/Linux_Install_Guide.rst +++ /dev/null @@ -1,15 +0,0 @@ -===================== -Getting Started Guide -===================== - ------------- -Introduction ------------- - -This document contains instructions for installing, using, and contributing to Composable Kernel (CK). - -Documentation Roadmap -^^^^^^^^^^^^^^^^^^^^^ -The following is a list of CK documents in the suggested reading order: - -[TODO] \ No newline at end of file diff --git a/docs/source/Makefile b/docs/source/Makefile deleted file mode 100644 index bde66ebc258354f0c0a902e3b9bc371d9fbf2117..0000000000000000000000000000000000000000 --- a/docs/source/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SPHINXPROJ = CK -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 65ac187034b59422bcf332df33eeaaac6db087ac..0000000000000000000000000000000000000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- - ies of the Software, and to permit persons to whom the Software is furnished - to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- - PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- - CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -""" - -# -*- coding: utf-8 -*- -# -# Composable Kernel (CK) docuumentation build configuration file, based on -# rocBLAS documentation build configuration file, created by -# sphinx-quickstart on Mon Jan 8 16:34:42 2018. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - -import os -import sys -import subprocess - -read_the_docs_build = os.environ.get('READTHEDOCS', None) == 'True' - -if read_the_docs_build: - subprocess.call('../run_doxygen.sh') - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = ['sphinx.ext.mathjax', 'breathe', 'sphinxcontrib.bibtex'] - -breathe_projects = { "CK": "../docBin/xml" } -breathe_default_project = "CK" - -bibtex_bibfiles = ['refs.bib'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = u'Composable Kernel (CK)' -copyright = u'2018-2023, Advanced Micro Devices' -author = u'Advanced Micro Devices' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -#version = u'0.8' -# The full version, including alpha/beta/rc tags. -#release = u'0.8' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = 'en' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -# html_theme = 'alabaster' - -#if read_the_docs_build: -# html_theme = 'default' -#else: -import sphinx_rtd_theme -html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -html_logo = "rocm_logo.png" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -html_theme_options = { - 'logo_only': True, - 'display_version': True -} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# This is required for the alabaster theme -# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars -# html_sidebars = { -# '**': [ -# 'relations.html', # needs 'show_related': True theme option to display -# 'searchbox.html', -# ] -# } - -mathjax3_config = { -'tex': { - 'macros': { - 'diag': '\\operatorname{diag}', - } - } -} - -# -- Options for HTMLHelp output ------------------------------------------ - -# Output file base name for HTML help builder. -htmlhelp_basename = 'CKdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - 'preamble': r''' -\setcounter{tocdepth}{5} -\newcommand{\diag}{\operatorname{diag}} -''', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'CK.tex', u'Composabl Kernel (CK) Documentation', - u'Advanced Micro Devices', 'manual'), -] - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'ck', u'Composable Kernel (CK) Documentation', - [author], 1) -] - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'CK', u'Composable Kernel (CK) Documentation', - author, 'CK', 'Composable Kernel for AMD ROCm', - 'Miscellaneous'), -] diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index 68adf58afd8c620a2dc0d046b8d358451949ee76..0000000000000000000000000000000000000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,16 +0,0 @@ -============================ -Composable Kernel User Guide -============================ - -.. toctree:: - :maxdepth: 5 - :caption: Contents: - :numbered: - - Linux_Install_Guide - tutorial_hello_world - dockerhub - Supported_Primitives_Guide - API_Reference_Guide - Contributors_Guide - Disclaimer \ No newline at end of file diff --git a/docs/source/rocm_logo.png b/docs/source/rocm_logo.png deleted file mode 100644 index ee09dd09c71e3de1081f0f4c8b67967f79d768fe..0000000000000000000000000000000000000000 Binary files a/docs/source/rocm_logo.png and /dev/null differ diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in new file mode 100644 index 0000000000000000000000000000000000000000..83dd1e7b1a5d759143f2db13a1e296f5e66da96d --- /dev/null +++ b/docs/sphinx/_toc.yml.in @@ -0,0 +1,10 @@ +# Anywhere {branch} is used, the branch name will be substituted. +# These comments will also be removed. +defaults: + numbered: False + maxdepth: 6 +root: index +subtrees: + - caption: About + entries: + - file: license diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in new file mode 100644 index 0000000000000000000000000000000000000000..4bdf41b959975f9b7f6ba38d20d5070f83e4da1e --- /dev/null +++ b/docs/sphinx/requirements.in @@ -0,0 +1,2 @@ +rocm-docs-core==0.10.3 +sphinxcontrib-bibtex==2.5.0 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..097acba2257d3c3f1c03d801ebeebdccd0dd120f --- /dev/null +++ b/docs/sphinx/requirements.txt @@ -0,0 +1,172 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile requirements.in +# +accessible-pygments==0.0.3 + # via pydata-sphinx-theme +alabaster==0.7.13 + # via sphinx +babel==2.12.1 + # via + # pydata-sphinx-theme + # sphinx +beautifulsoup4==4.11.2 + # via pydata-sphinx-theme +breathe==4.34.0 + # via rocm-docs-core +certifi==2022.12.7 + # via requests +cffi==1.15.1 + # via + # cryptography + # pynacl +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via sphinx-external-toc +cryptography==40.0.2 + # via pyjwt +deprecated==1.2.13 + # via pygithub +docutils==0.16 + # via + # breathe + # myst-parser + # pybtex-docutils + # pydata-sphinx-theme + # sphinx + # sphinxcontrib-bibtex +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via rocm-docs-core +idna==3.4 + # via requests +imagesize==1.4.1 + # via sphinx +importlib-metadata==6.0.0 + # via + # sphinx + # sphinxcontrib-bibtex +importlib-resources==5.12.0 + # via rocm-docs-core +jinja2==3.1.2 + # via + # myst-parser + # sphinx +latexcodec==2.0.1 + # via pybtex +linkify-it-py==1.0.3 + # via myst-parser +markdown-it-py==2.2.0 + # via + # mdit-py-plugins + # myst-parser +markupsafe==2.1.2 + # via jinja2 +mdit-py-plugins==0.3.5 + # via myst-parser +mdurl==0.1.2 + # via markdown-it-py +myst-parser[linkify]==1.0.0 + # via rocm-docs-core +packaging==23.0 + # via + # pydata-sphinx-theme + # sphinx +pybtex==0.24.0 + # via + # pybtex-docutils + # sphinxcontrib-bibtex +pybtex-docutils==1.0.2 + # via sphinxcontrib-bibtex +pycparser==2.21 + # via cffi +pydata-sphinx-theme==0.13.3 + # via + # rocm-docs-core + # sphinx-book-theme +pygithub==1.58.2 + # via rocm-docs-core +pygments==2.14.0 + # via + # accessible-pygments + # pydata-sphinx-theme + # sphinx +pyjwt[crypto]==2.6.0 + # via pygithub +pynacl==1.5.0 + # via pygithub +pytz==2023.3 + # via babel +pyyaml==6.0 + # via + # myst-parser + # pybtex + # sphinx-external-toc +requests==2.28.2 + # via + # pygithub + # sphinx +rocm-docs-core==0.10.3 + # via -r requirements.in +six==1.16.0 + # via + # latexcodec + # pybtex +smmap==5.0.0 + # via gitdb +snowballstemmer==2.2.0 + # via sphinx +soupsieve==2.4 + # via beautifulsoup4 +sphinx==5.3.0 + # via + # breathe + # myst-parser + # pydata-sphinx-theme + # rocm-docs-core + # sphinx-book-theme + # sphinx-copybutton + # sphinx-design + # sphinx-external-toc + # sphinx-notfound-page + # sphinxcontrib-bibtex +sphinx-book-theme==1.0.1 + # via rocm-docs-core +sphinx-copybutton==0.5.1 + # via rocm-docs-core +sphinx-design==0.3.0 + # via rocm-docs-core +sphinx-external-toc==0.3.1 + # via rocm-docs-core +sphinx-notfound-page==0.8.3 + # via rocm-docs-core +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-bibtex==2.5.0 + # via -r requirements.in +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +typing-extensions==4.5.0 + # via pydata-sphinx-theme +uc-micro-py==1.0.1 + # via linkify-it-py +urllib3==1.26.15 + # via requests +wrapt==1.15.0 + # via deprecated +zipp==3.15.0 + # via + # importlib-metadata + # importlib-resources diff --git a/docs/source/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst similarity index 100% rename from docs/source/tutorial_hello_world.rst rename to docs/tutorial_hello_world.rst diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 495a8159623beb22d6002bb2e1667ef459b74139..144c9aaccd7f8b1e20e30e671be4be2525ee1d03 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index cf585a8c51cbb9b5b0218228e4aa706189598804..b5fecb97521bd2d4213a9435b01dafe45a43be6e 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 93f085cdee53667a5b906cfed3b037d57d00bf5f..212b72f2a6a060a602c5c640709da77684c7c55e 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e392c490f29a48da3a6424f46876e220306f2907..e55ae140130c0779c1e81e0bdc84b6724c6bac86 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index be9e387718f120fb1ba708d374f0ff9a09fc806d..1840390aa9e02e85c88328baf358546db362ab24 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 58f965be8817bb5761dd677a52966d21c1bc1024..b11fe76ab2ce031708be7addd7f84be0f15adc6b 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 9aaae6ade9564aa54bc4ea0c7f2c96aac89ff505..3cac55ef4702856dd08dc223c02ea25113dbbf32 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 50d35fd9ac98de821fe4861b2dbf34907d3592eb..54fbd9cdd4983c1a810f3c8a1ac1521b363af690 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 99253b743d58707707c9765c18e5933cb09e4220..8361576299c3da3727cc467387e9717fceb93ce3 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index 7f1283a47b36c94be9645abd9fbd63094293f46e..f6238c7aa5040d8e409a139fe8c144835282cdc9 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index e67594c5bcbd601ae4747ac7720dab66b605ba2c..cc03200b9d153e8e22e13e6c77a52ad76ed79174 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 12a69925977b7eef2e5aa178016a7d433e9c0781..3afd0ebdb955dbcf949cfa2467a85cdf325c5f4f 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index 3a0ddd90b700f7086b44a8b0f4e5b9d7385d0ad2..d7176f75daa65da438565fc6c7c3923873485a6d 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4e2cedb52adc8c7030167ec7f171d91f0796f234..38c72afc60c693ff25fbf8d6679ea7f22ebfdb13 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 16a82110276b7901df9e80fe3a8e3ccb64df22aa..7ea4564d1b7e6562c4f34197cc691c0f477610df 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,4 +1,17 @@ -add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) -endif() + set(target 1) + endif() +endforeach() + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index ff99bf4641115e466775995de2761ba52897783d..877792d7409f48ee562cded6629f7e1bbde650d9 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index 917b6b1c3142ec8cf7d1c852ea59494fea1842a9..c3e6ef7d5df2e12de69e2f9465f8f3e3709c0859 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 35c54abac03094f24187df2503aa02b6812c20f3..2f5cba924dfd044a1ff0717d75a8b6e7efc3d16b 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1 +1,8 @@ -add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index aee51d05de58b6386a5dd267d45a7e5c12d98276..dffeff23374b6f868fcce952ef8991ca289f59b7 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index c75c5ba51e899e4d0cba6564f52a2abb63854dd8..447fa9871b75906aa4b74f15654181f00c129f9b 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,17 +1,24 @@ -add_custom_target(example_gemm_add_add_fastgelu_xdl) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_add_add_fastgelu_xdl) -add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) -add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) -add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + endif(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) -if(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) -endif(USE_BITINT_EXTENSION_INT4) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + if(USE_BITINT_EXTENSION_INT4) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/04_gemm_add_add_fastgelu/common.hpp b/example/04_gemm_add_add_fastgelu/common.hpp index 839587c148956cf186e1ba5ac286a1b73797fade..91d17df95fc470ea23011999d0a929b9331d0cc2 100644 --- a/example/04_gemm_add_add_fastgelu/common.hpp +++ b/example/04_gemm_add_add_fastgelu/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index ba0476b9b9e3bc031cb0ea61eda66970c2544f39..e630f6783713517977f63a9a2a1ef20836af2c61 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp index b940bfd89737ff4f460ef2b672a99a6696d57ce3..71f6677bae95d1779c76219f6c8b5bc05c3c1d31 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp index fa651a34ea86298a60dbea6158c356b9e50fdfce..4665c3932f419c837125ea3153437302cb9dc022 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index 9f9c423de278b8687816dac1d518f843eb2eb34a..f206bbeb411bb0de0144015d70a892b5587d6f0e 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp index fadc4ef5ee47737bb33e19f1bab3264f99de5ea2..e46483ab38ac34cb5d4f97602474887d78779327 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index e0a53005b8c2846b6c3baaee40272c9ad0858188..90104c16308204c4c66282c7950c07485df0df1e 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,11 +1,17 @@ -add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) -add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) -add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) -add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) + add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) + add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) + add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed + add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + set(target 1) + endif() +endforeach() add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) - diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 4c594ccdf817a069918b348d7b23c6cc533bde28..109b8f9ee34096601f0dd485db6840e0674c1937 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp index 855710b9d9a30ec47c74a6f31d1b03ade688e28e..aeddd4fc59fc521a196487bf7c0ea4e88bf6e1c8 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp index db5a7f0bc3351c9f1670993942a54668c191e56f..7b6f18f46a19c9f0d7c99093fa180370acfc39c0 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp index 964d784c8592b2e94bc38cc8eab6b7a645d88fc1..551655b17739b6d7250ddc447482c4141948f6fd 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp index b0cd88f214c8550c018ad9f4d268370dbd816c7f..27a3f2e2a9eacc7fadfce307d44f3ee9fd856585 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index d55d3154916b0f4ba2fee195b7957904ae3d3139..74cf91d16017bd24096dac6348e5a8815856a51a 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index d84afba6426b1abe4b10a3b4da1fefe7c2e9272a..f6d69bafd48a4b8bff968c2fb8600793b1310294 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index f5acc540cf98b8876caa0d5c0812436dd200deb8..6c3171f6157ac4a83a488af12089a97accfd2c7c 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 8d697976abd43e5d07c49ec4e0819cbdf15ac02b..9977a496d229876c083e0afe99e17ff0fa5fca05 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 99f7f2565c748a950841dffa12a1d343df0051ab..bf084b3cc0b6ceab326b2b21e963e0ce2d4bdcad 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc index 697ada14ba960f7231db3aa0d5b2482ca38578b8..6474df1c355bc059af5429f96616b2192f46f292 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 36a68056f1d7a34b4309a06941f0ff1f477eaba8..49852ff6678f7c3fbbfce4ac8b1303876092cf40 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 98941b4db53b022388023e5a0eba10c99b5881be..9577b4569a38b4f4fa5d7d859f0f336920bb3100 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,16 +1,20 @@ -add_custom_target(example_convnd_fwd_reduce_xdl) - -add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) -add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) -add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) -add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) -endif(USE_BITINT_EXTENSION_INT4) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_fwd_reduce_xdl) + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 00e370f2968df0409155d5407066e863574613ce..137b0d1ff0fb0a7cb7159d52f54b4cd3496629ae 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp index 6ff29b4b0ff0f11d6c9d3becbd3aaafdccf3020e..4ccacb0bcee93e72eb1fd23203e9404f1c23c478 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp index 02c19c2b63bf4f9ee663140b414c89f6221cff51..bf495725e8e77a37019d1d858bf3fe3b69e828d1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp index 679bb5c0c45a1e2e74fc56d21ee6f7f73cae76ad..5848785673c340149d622fd782671d5fb174edab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp index abdbdaf74d5c0bd2b87da17d315e8712cd896d52..bf7127502faac94e6858d37c389f316b7ead35a6 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp index cf86afa8e94957c01fc92acb5ed2286fcb52466c..3e1694cbe8cae679ef249b6654efc40f8b1fc1ab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index b3a3891781766acd824bc1974b88f57fdd85b711..cebfeb51d63eac30243d8e3b0468b821f6fc1eb3 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index a7ee9990c1941a73b632f3cc1d32b14a00897cc2..9a736d4cfac9d1062c568e795592711e30d18586 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index e6e3cc8d52bc9fe6568d98194e644cf75ad1950a..7f8394a7301cc3da8271e1e8c41482d0b06a2859 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index dbb18a0d83f62d27750c7797842af9bea0eb1383..eb8b5c76d31a1a0fc7a1a7e085416f80be9abf7f 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_example_common.hpp b/example/12_reduce/reduce_example_common.hpp index 05f0a0edb25ba520d475522bd2963bc050ccc010..5f9a48804a7268cef5a4a065d0c875565bdb3e9e 100644 --- a/example/12_reduce/reduce_example_common.hpp +++ b/example/12_reduce/reduce_example_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_multiblock_atomic_add.cpp b/example/12_reduce/reduce_multiblock_atomic_add.cpp index c4d63a3add8b0d8016b11b8834a1385dcca5beb5..120e3f05957fdb9f937b290dfe09458db248f9c0 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add.cpp +++ b/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp index 905242fb6b5bca3c0dd7209a5d2827546036695e..fed62186448d58ce58c2a13c6dd184e205801185 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp +++ b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index b83cb6a96f0b10f0aeff5e57d386f2ac1c3fa4e4..1157ccd38705b92c759a6de1181fe2b135425091 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,115 +17,11 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" template -static void pool_host_verify(const Tensor& in, - Tensor& out, - Tensor& out_indices, - const std::array& window_spatial_lengths, - const std::array& window_strides, - const std::array& in_left_pads, - const std::array& /*in_right_pads*/) -{ - const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; - - using ReduceOperation = typename ck::reduce_binary_operator::opType; - - auto elementwise_ops = - ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); - - auto in_elementwise_op = std::get<0>(elementwise_ops); - auto acc_elementwise_op = std::get<1>(elementwise_ops); - - if constexpr(!OutputIndex) - { - using Accumulation = - ck::detail::AccumulateWithNanCheck; - - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::template GetIdentityValue(); - - for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) - { - ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; - for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) - { - ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && - wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - - in_elementwise_op(currVal, currVal); - - Accumulation::Calculate(accuVal, currVal); - } - } - } - - acc_elementwise_op(accuVal, accuVal); - - out(n, c, ho, wo) = accuVal; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::template GetIdentityValue(); - IndexDataType accuIndex = 0; - - for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) - { - ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; - for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) - { - ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - IndexDataType currIndex = y * window_spatial_lengths[1] + x; - - in_elementwise_op(currVal, currVal); - - Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); - } - } - } - - acc_elementwise_op(accuVal, accuVal); - - out(n, c, ho, wo) = accuVal; - out_indices(n, c, ho, wo) = accuIndex; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - }; -} - -template window_spatial_lengths{{Y, X}}; - const std::array window_strides{{window_stride_h, window_stride_w}}; - const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; - const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + const std::vector window_spatial_lengths{Y, X}; + const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; // tensor layout auto f_host_tensor_descriptor = @@ -219,14 +116,16 @@ bool pool_test(bool do_verification, static_cast(in_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), static_cast(out_indices_device_buf.GetDeviceBuffer()), - N, - C, - std::array{{Hi, Wi}}, - std::array{{Y, X}}, - std::array{{Ho, Wo}}, + {N, C, Hi, Wi}, + {Y, X}, + {N, C, Ho, Wo}, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, window_strides, input_left_pads, - input_right_pads); + input_right_pads, + {2, 3}); if(!pool.IsSupportedArgument(argument_ptr.get())) { @@ -252,19 +151,28 @@ bool pool_test(bool do_verification, if(do_verification) { - pool_host_verify(in_n_c_hi_wi, - out_n_c_ho_wo_host, - out_indices_n_c_ho_wo_host, - window_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<4, + 2, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ReduceOpId, + PropagateNan, + OutputIndex>; + + auto ref_pooling = ReferencePoolingFwdInstance{}; + auto ref_pooling_invoker = ref_pooling.MakeInvoker(); + auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + ref_pooling_invoker.Run(ref_pooling_argument); out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index 659f3251dcf9f8d7661cf44169458937f0d15088..daf8540d4936806904b8d2630f60fa9c094b1b79 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -1,8 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include -#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -10,9 +9,9 @@ #include "pool2d_fwd_common.hpp" -using InDataType = ck::half_t; -using OutDataType = ck::half_t; -using AccDataType = float; +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeDataType = float; using IndexDataType = int32_t; @@ -91,7 +90,7 @@ int main(int argc, char* argv[]) bool pass = pool_test -#include #include "ck/ck.hpp" #include "ck/utility/reduction_enums.hpp" @@ -10,9 +9,9 @@ #include "pool2d_fwd_common.hpp" -using InDataType = float; -using OutDataType = float; -using AccDataType = float; +using InDataType = float; +using OutDataType = float; +using ComputeDataType = float; using IndexDataType = int32_t; @@ -91,7 +90,7 @@ int main(int argc, char* argv[]) bool pass = pool_test #include diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp index d5f4e6f62c3eb8e58285e81446a5e1332c37c533..aa3e0116954e603ed5cd759820ea49184bb4cb8a 100644 --- a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp index 2371737382447d666d55633fbf3ac976c3017cf7..4b207df5c628976c54d8b9c0f53ea140d51c98e5 100644 --- a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index e3080c4a9ae0aafd9cd93072503a6eaeeee68495..9df256c38d59f5ae48ca9539b5a3d637c9446c8f 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) +add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) add_dependencies(example_grouped_gemm_xdl @@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16 example_grouped_gemm_xdl_bfp16 example_grouped_gemm_xdl_int8 - example_grouped_gemm_multiple_d_dl_fp16) + example_grouped_gemm_multiple_d_dl_fp16 + example_grouped_gemm_xdl_splitk_fp16) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp index a5c51ceb0cfbc1a19215c8a2313b12a3b4d5ec51..3e1f7f089371b17fb02ebca53b27b0860c7b9c72 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp index 05d572a1f532446bdc0ec43e59b1eeadfe4bd054..680cee1f814db7cef177e87e4d9c03ba745b8e26 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 3f78dafa8977b1b57592500138dd8f034b294efd..90a12bc1ddbf27e04a55de23e3fc6845868dac43 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fd93bb5f87d49a96da8973c3f862b8b48bdeb180..28b0fcd0cea6d6aed46ffeabc35c4c5ebc72b864 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp index faf41bbf0bba8e06fee9982a2f158706401ca45d..60c4a71a35146c31d9b79703faf5d3422942d817 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp index 7cb09778c521c9dfaed38c39ee5cb7fcff7bbc26..0c96ef56d3c144f19dcb4c7101a8d481d8960ee6 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..743ab96be6291e617cb012957865a6897bd61a91 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = { + 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 324e1772805ac2f1b9b84872d3b8baec3c39c6b5..320870e0de7cab9dee43cab6542f13ca5341a90b 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co #else a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); #endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 226656a736637beea41758bf59dd1bd27ed8189f..a42b427c6152e90a0f04b5cc8783927c32dae844 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,40 +1,47 @@ -add_custom_target(example_gemm_reduce_xdl) -add_custom_target(example_gemm_reduce_xdl_max) -add_custom_target(example_gemm_reduce_xdl_mean_meansquare) -add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - -add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) -add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) -add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) -add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - -add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - -add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) -add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) -add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) -add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - -add_dependencies(example_gemm_reduce_xdl_max +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_reduce_xdl) + add_custom_target(example_gemm_reduce_xdl_max) + add_custom_target(example_gemm_reduce_xdl_mean_meansquare) + add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16 example_gemm_max_xdl_fp16 example_gemm_max_xdl_fp32 example_gemm_max_xdl_int8) -add_dependencies(example_gemm_reduce_xdl_mean_meansquare + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16 example_gemm_mean_meansquare_xdl_fp32 example_gemm_mean_meansquare_xdl_bf16 example_gemm_add_addsquare_xdl_int8) -add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) -add_dependencies(example_gemm_reduce_xdl + add_dependencies(example_gemm_reduce_xdl example_gemm_reduce_xdl_mean_meansquare example_gemm_reduce_xdl_max example_gemm_add_add_mean_meansquare_xdl) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) -endif() + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + endif() + set(target 1) + endif() +endforeach() diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index eb3832a668ff9829ee76e87e6ea8988b40e6980b..2f6533d4481a2f339417a19f0f290f7fd306ee9c 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index e1248002f751abce96072ee64283ca00022a8713..b28e7f85d3137915a2fe438b16aef5447b2d4d24 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp index c2feffeb8953df030ad2e30cb942bec529b2cac1..b30ce2c48ad32d99dc5c8a4aaf7a3ee9ea774baf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 363390add3e99c96ef3891bec995cda5f8c3e24b..31e2efd6f635e30e569d97f73b4a5967b4909b82 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index de6b7eb480b30d4347018da94ec518ff29fdbeec..d3c7c1d99c06a222bcefae42d94cdbb6fc8e1239 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp index 9666fc6622cdf3ab7cbab7071ebffd69d18b7668..9a4a6bc6e11a657b9d1eeaa35dd92552e397130d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp index 00e0b767a45dd979868b81635a79600d4695c779..1a8457a8bf87cdc82e687246014d2cba73065b8f 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index 652c0e6ea6d2bbe3321de2deefb2c186d7252023..5c2706c79ace02e72f2956270c322c4b00782ebf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index 7eee24fed83988698e0e6d46f830f3a9dfdb0257..c119e243702a27df9bba2d6cf3f791f523c98d7d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index c250b996928dd406bde9253ccbf40e6cd9d77347..0f5e588383abf30064d27e452131ceef1a7ad828 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 62992de59765d3e05d935ac69894c4550a7ef472..1bea1bcf3e0fc1494e07952020db175a1d91bfbf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index fa4e65d965f16c7049a21dba93e881fa04cd5319..8ab9f37bb360edd0449f31b13d4be958341317dd 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,11 @@ -add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) -target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + set(target 1) + endif() +endforeach() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 26fa9e9821fc050c9ab69fe1a170b4c72a98dcd3..4a9d16c5c303e43989c1f32e51c2cbce6f279e5d 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification, in_device_buf.SetZero(); // do GEMM - auto conv = DeviceConvNdBwdDataInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op); - - if(!conv.IsSupportedArgument(argument)) + auto conv = DeviceConvNdBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.GetOutputSpatialLengths(), + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument.get())) { std::cout << "Not support,please check parameters or device"; return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp index f0896e977144867421b857afd0a58a239716712b..6b84eaba471a966c89b62d9f3be32d9b0e0834d6 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp index c4f2c1f02bb9010ab3f5d97166819732e911b68a..c9989c60ac2c74e6821bdf6dd40e703998accf01 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 99fc0043d2803494522b446cabf4a3f98ca12a5a..94ed129dc03664043b1a0295f9f1dcfb38a77002 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,2 +1,8 @@ -add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) - +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index c2e3602a7bb22e3f31c66bb6045171229584fa14..e363dc5c12dd84a36a67e969b06d29a179679f94 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index bee5dea546f6da3cc23aa539583915c9b8a3b21e..24c8d82f674d97afae8cb1b3dc0274c1c13daf35 100644 --- a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp index 6fc63b899e5c6abc91cad26c92c7d7733f0e8f5a..3c04c561403d0ecfb0819aa569bdf24511053442 100644 --- a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index a5a6bc0a8bedee0a16906381c91a9c93b070e572..1ac09641a1e20a6be8bf44f4f1ff0f3be045dedb 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index cc209b12e3df12413beabfd11961671f4e303b41..e571aa8468008a5f4e7097eee74763a583faf9af 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index cbe4f5f4869d77c7b176980193cd03403c3d5c1b..db170deccee8d03772bb280b334f33483f6c7a7f 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,11 +1,17 @@ -add_custom_target(example_grouped_conv_bwd_weight) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) -add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - -add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 example_grouped_conv_bwd_weight_xdl_bf16) + set(target 1) + endif() +endforeach() add_custom_target(example_grouped_conv_bwd_weight_dl) diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 3f4818d2e3336799c3dfb5cca4fb14d9428ab58a..15727495f0f1f8dfedcb24a0ffa02d3d7aea67e0 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index aed6d22b023aafc35e60fac29b6375a8602b87d5..3cd70d0f36fb12a33e05a48545fe0aaa5f0f34bb 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 4a2a6195d9574c26e3526d570bfddff067bd3e37..966b5867672b2a641210fb275688f2f19c70bb8e 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 7891812375f060768558556a7025bbe38458c379..39b9100bf8b34f26055798d745cb3cf422e53d70 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight 1 // TODO: Add Dl op split_k > 1 support - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) + if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102")) { split_k = 2; } diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 2eb7052e1eea21ea652f088813618b15256fca10..ff870e7c6b15bca0c9c91c730020fedd0e61546c 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,4 +1,11 @@ -add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) -add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) -add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) -add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) + add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) + add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) + add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 192fe87b626ff721ca106b9c633c109483f5772a..96d04dcb37798fb6db3334bd2d802e435fdc93fd 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index 3f01e6947728ecbef57734190b3896ff53899b02..fc58ca19f8673aa0b2205df21350f7ddf0b92196 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index 4da6da65f7ab21b6ffae367747ccf5f0c71232dc..bd1d6932aceb72890dcb0f7c35004cc0dccd93af 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index e7d857c4a0fa53a262243b759ddf9b33c691e26c..90d80f9f034b391f75c498f3a34232edf64f5260 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index 92ed90ce4ab3ec09da2e053ceebe4d4ed6bcf36d..fa4482a984f20d203bd8bee68614c8e71ecbf83c 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_common.hpp b/example/22_cgemm/cgemm_xdl_common.hpp index 6aa06b7c32cb476b61a88235e13a12a3d1a15db5..26137a7c2e50d0d5010639592c0a614e2eca607b 100644 --- a/example/22_cgemm/cgemm_xdl_common.hpp +++ b/example/22_cgemm/cgemm_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 11373736ee8b37efb5f4082253d46f02074f011f..89a581e865a56b231711352f9733403ed2945aea 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index 0f45c18c4818726f179adda70dc12b5ea7c45b9d..cf9659959990823f10ef0cb30fa0eef928c32adc 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp index c26a83baafd375b477845fbdd793c3918d7b7dc5..f69cc2b3cc4f98e94ddbf0b37e8ac7324515a8d6 100644 --- a/example/22_cgemm/cgemm_xdl_int4.cpp +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int8.cpp b/example/22_cgemm/cgemm_xdl_int8.cpp index 2f24189861d8468848d28b9c0480da7c3f2c4fdd..c4835b853ee75496f963a8e14d7a0190eb4cb1cf 100644 --- a/example/22_cgemm/cgemm_xdl_int8.cpp +++ b/example/22_cgemm/cgemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp index 41afd72f5ac9ddc3dc030a914e1a26da60c27356..d09e434bcfbb262a83c5167429d41d1ca54391ef 100644 --- a/example/23_softmax/softmax_blockwise.cpp +++ b/example/23_softmax/softmax_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index c934d35019602e634da6f8b6d49c1c3b133c6800..420a7cf74f3186ac62d5dc37346a202178d8d273 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp index 98835f98fa6ecea9c38f1155f5dd03d0787fa0d2..9d606db205dde86d544971b4b7fc4830ba73c568 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index ea105e4ff2bb9a33830fd1182007d658c52ba7e0..78522160c85a1024ec510197e5ab1069e6f077cf 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -385,22 +251,22 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 9a000377bba2b92fe7509fc07821faae3634b5c0..6cceed5bc11ed8eb79568a0397bbf7e1b93fd33d 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -385,22 +251,22 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 26f176b0591d02a821fa783542e9538791e4cbcd..1574f5d18fb7f5f193cd4adb1c1a88616d7b1437 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -368,22 +234,23 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 38ed60266de9e5a8322cf0fffa5fcb9eb33cbc84..3dacc708877d838dd5fb0113b5d560885b0551ae 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -368,22 +234,23 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt index d96deae45e49acdc8b86f9973417d6a77e1e7e31..94c23ce77499e0c110301a3d3ae82b0ec119a7c7 100644 --- a/example/27_layernorm/CMakeLists.txt +++ b/example/27_layernorm/CMakeLists.txt @@ -1 +1,2 @@ -add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp) +add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) +add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62a71713df84351fea902cdcf3275786b60f5a23 --- /dev/null +++ b/example/27_layernorm/common.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb8b954f0acaf2e9d748104a05561f9cb02fc9a0 --- /dev/null +++ b/example/27_layernorm/layernorm_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector +#include "run_layernorm_example.inc" + +int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0378d028b343a6a70335cb9163370d1586ebf0a --- /dev/null +++ b/example/27_layernorm/layernorm_splitk_fp16.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // YScalarPerVector + +#include "run_layernorm_example.inc" + +int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/layernorm_blockwise.cpp b/example/27_layernorm/run_layernorm_example.inc similarity index 57% rename from example/27_layernorm/layernorm_blockwise.cpp rename to example/27_layernorm/run_layernorm_example.inc index 7d91b69d04767e00a3cba38f60f43153bdb0d036..95200b540aa9f9704c8fad785b352bdf779c4094 100644 --- a/example/27_layernorm/layernorm_blockwise.cpp +++ b/example/27_layernorm/run_layernorm_example.inc @@ -1,58 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_common_util.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -int main() +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +int run_groupnorm_example() { bool time_kernel = false; @@ -111,6 +63,10 @@ int main() return 1; }; + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = device_instance.MakeInvokerPointer(); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -133,7 +89,8 @@ int main() ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); - pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); } + return (pass ? 0 : 1); } diff --git a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp index f8e6501eadaee4e5b4156309afebdf16e11f3e8a..24e9b1d9b7d569adf20caf1bd0f0cddafdbd256b 100644 --- a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 30ad38a566309dcfacd52f100099d8a3bf02f77f..62233e535151e6b647e938de04b6225fe487b6c4 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp index 25d815b9cdfdddc29c699a1bc51dd5191a5246eb..08158bfc250642e032d20fad940bb8853c71d446 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 4b0ea4f1579694250bcd04f0c4374728eadbf1ff..1c07538b09c5261cfd3e5d34f68bccd03087e6e0 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,25 +1,36 @@ -add_custom_target(example_grouped_conv_fwd_multiple_d) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_custom_target(example_grouped_conv_fwd_multiple_d) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) -endif() # USE_BITINT_EXTENSION_INT4 + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + endif() # USE_BITINT_EXTENSION_INT4 + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + set(target 1) + endif() +endforeach() -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) -endif() - -add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e7c6ed9b939abb0e8f593d54d5b45c499db5f2c0..e60ebee6e4fe2e08db432816b75bbbd4c388946c 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index eb6975a6d812aef66f1caf493dfe82549a0e4ae9..ae769ff1d38306f4feed8d3cb00dde79b3f561eb 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 9d1d257a2889dcdaefb697faa6bcf6c172c55839..039d25029921491e7e67808e554b8cb3e6eb4745 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..793324970e700748ac4b18e9fc429a24f70c007d --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common_wmma.hpp" + +// kernel data types +using InKernelDataType = I8; +using WeiKernelDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I8; +using BiasKernelDataType = I8; +using ResidualKernelDataType = I8; +using OutKernelDataType = I8; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp index ee300d073a28b44cf8177f64a7278d5e13ef3130..43c0d57dc2adcae69ecde327137eb204b71fa44d 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp index 5a9df0b1e880c8f3bc954c07cd5de3639644c4e1..40b4132b358d8259146c2987472dc7e9d6c71091 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp index c2906cc9dd1dbe22c8d420ce0929783fb0be0fa6..e05d384f26b40273a6bfe77222b805289a99e631 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 3d5a243e6b9282f9d362c82c0758776341df646f..5494563fdd568d1f51ec9ee9042d094096e8fc8a 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp index eaf680fa438a14ee402e3805c18692ffdcf78c7e..6bf2e8d963c1524bfec0680409580b47e5394bc5 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp index 6de1daa3d454cf556855322deb4b71198fae5d9b..498eda2442fb737a57d15ec0fb58b71a9049f49e 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index 4561156e0bd3a935c49751802e8c3037d8735228..eb242203eaa65ba7b8fca45c5f1e3e1b9d96c409 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 8161b1088ada41e680ea3efd305156d2dbd00c51..360b2c8947b371b1502afb3af73c469f0d8c3636 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting @@ -74,8 +74,8 @@ using DeviceConvFwdInstance = 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 1, - 1, + 4, + 2, S<1, 32, 1, 8>, 8>; diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index d087c31af5def501635ef173769851e111dec3cb..58ed69182e0622751dc3fa2823fcb7cfa5b279a8 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template using DeviceConvFwdInstance = diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index d79248251c335c5bc7d5ee2a71c14ea3ac60ac88..83989bcd972c23eec7b814bd09858bf19bb1c3c9 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,8 +1,19 @@ -add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) -add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) -add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx908 gfx90a) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) -if(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() + +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) +endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp index 74e0e07e62a99e0700213e80a066d7e15128e516..7605d9c4f8368e369e4c29167f90620984b5a6c8 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp index d5fadb8081e86a83784f6fad7e7405a7808b65b9..33ed04fb3068b1bf83eb128825ccad355c6de4f9 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp index 0dd4e0914f479eefae604ceba8115432ce1d3803..e0eb193ad0484c62ac6fa5695ac43aa17c509227 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index 1fd93622a1b14a72d8159517a84b9f19d2959723..d166214c3376cd90afd11796a4cb85ea28421861 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp index 15d98abab7dfa0f1168fd411eeae481054053000..40f87d1f554c4ef2a73ceb32263c3499f733b40a 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 7e5f1614bcf6c6133d88f9d8dad63ec4be47c9ac..f329146728dd8df3c38af3d2520ac91d43d12d00 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 0eb15653306f0179b1d132ddfcd130af6ba13fd6..1d1566d57561afaa84be2de34b11416b5da0571c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp index 8f1db577c604418b7b1d546b5de251f5a20054a6..bae88d4b8e6f786219e8d68bd0a34f6036dd0d72 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 2ce91a8c6023314d61b354614c81ca74c5f95727..a098ce6675e05308fd9de223e807a0e14504a7f5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp index 1fd2bf69306f5b35dd4ea2ae484c677c6b3c63fc..ce8caf758842225098baac92876ea6996321ff21 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp index f4a8589052f0f149902553884a58cc4c6c79d030..138db14963809cd49294d1cfc2171b038dd10c60 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index e4a71b04313446f75f4cbc73456a66bdbb62c47e..d0eb8fcc341818950bd39b7f91d95ca4f5eb4e03 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 38b5badc6e4f270e409ed98c492035c5bebe00c3..1d97474d205160e4a34d683f40f86f5e71b0a429 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 4e43dbdd8fc5a83486cd052ee3e789e05152861f..27602e2313f7aa197e88e1fabeb39245e2fdf5eb 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 0b876af952f65f89e9ae55bea344bfa6f61afd5d..fa76faea84e4551ddf8d0617c132dbe7a6045fb3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index ef2acf61f55fa7f108e174d4586c13ae3dc7419d..ea1e2734a684b61a363eb93ef0e2ff933f900ea5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/33_multiple_reduce/dual_reduce_common.hpp b/example/33_multiple_reduce/dual_reduce_common.hpp index 326606752b25394d504cb054cf3026173c66e970..cd21790be6548abd0f8097613ee650d4407fc400 100644 --- a/example/33_multiple_reduce/dual_reduce_common.hpp +++ b/example/33_multiple_reduce/dual_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_multiblock.cpp b/example/33_multiple_reduce/dual_reduce_multiblock.cpp index 9360599ed9e8a7517bc244e2942d6949c1f71e61..198931749b19b8e8d91bf16c1a733801cb27aeba 100644 --- a/example/33_multiple_reduce/dual_reduce_multiblock.cpp +++ b/example/33_multiple_reduce/dual_reduce_multiblock.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_threadwise.cpp b/example/33_multiple_reduce/dual_reduce_threadwise.cpp index 56255839e567fa0386ba09b80738f5d0d2abcc23..7609edad3527e8219c2ef6351381d6e53455a67b 100644 --- a/example/33_multiple_reduce/dual_reduce_threadwise.cpp +++ b/example/33_multiple_reduce/dual_reduce_threadwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index a6ca9d150bd918966bae06a63ee4eb9da6be5f3f..3756310fd7dd47b111b9fde743cac844ae2c7669 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_common.hpp b/example/34_batchnorm/batchnorm_common.hpp index bdbc8ea8b88f40de2f25a2ec8c5a74ab5e38fd74..a1b8d253bf061c02511945b98eac9343c0d1a236 100644 --- a/example/34_batchnorm/batchnorm_common.hpp +++ b/example/34_batchnorm/batchnorm_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index dc2984851a02e0020c8c4cfa6d59e61bfe49c593..6a8002025a60b277c587190a95855cc262c33e17 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index da36d65a2954ee1ac3a94c617a219e6a7c2baf44..d6808181570d424b2d9584110c4bbf6c65f8c5a6 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_infer_impl.hpp b/example/34_batchnorm/batchnorm_infer_impl.hpp index 15170586b636825325c796cdc20eb70996a24bb7..d0b545b2a31d2fc095eece12ce273db823a5d2ac 100644 --- a/example/34_batchnorm/batchnorm_infer_impl.hpp +++ b/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 794583954674f403a4902a6fc1cb244b803479d5..57ac33fc9a187d23a11560a8a60939ca8f818506 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,17 +1,23 @@ -add_custom_target(example_splitK_gemm_xdl) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_splitK_gemm_xdl) + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) -add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) -add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) -add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) -add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - -add_dependencies(example_splitK_gemm_xdl + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32 example_splitK_gemm_xdl_fp16 example_splitK_gemm_xdl_bfp16 example_splitK_gemm_xdl_int8) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) -endif() + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) + endif() + set(target 1) + endif() +endforeach() diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp index 7191ecf50ab5252b37855072b4e03f1194df52f6..1dc21a6c23ea69d60f0c51fd70b544899496afe8 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index efdb315b4e5576aef8d992f6d57c231ac1f83a81..74fb16e15b019e145cdf13549254be1b09fbd8ca 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp index bc2e3d1d52b668a2ea15c9f4ea1db44c4faeb19e..7506f694204b3c3aaa564704eda4f358baf17861 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp index 4eb27824628d8f193ca6108a589794041fc2b2fb..7ebf9144082de53d64f3662e579dd6d246266f1d 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp index eefdbca6b1ae3e49cc221b3037e46888b4a9a9bd..0fc7a5cc29ccefdee6eef3f229715ee781e65223 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index f0a0cdf6f13df7081d15ab0099122542ccf6118e..d2337dcda5d30eb57e0e5f764f047cf21314a190 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 071e8a7431c057a32ea0061655d81a202774b923..b3d0ab6bf77aa85fa6b48fa503a3f3f0dde1bf12 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 9cf960c501ccb279febc8205505ca0edfdeefac4..de48093ac2b598e56afb31078bfecbceb10dc792 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,7 +1,13 @@ -add_custom_target(example_grouped_conv_bwd_data) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) + add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) -add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) -add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) - -add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) -add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) + add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) + add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index d07ee7bdc1c4b999f61a84727bb10fe7ea6f4c15..ca824b1075ced161760892fe6de53767b1079ccc 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp index 55ea8c3a3109469532d0e5c0566da3eb11e7353d..a3533bb4cc4b4874f0e7ae9f333347913d981ba0 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp index ddf82ec512c61c72a1a2e2323a6645f39f15ce6d..fb688b6f3f15383c2920a12e8902c2d4cf05a109 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0afd8bd70da849085db6e19583e933c3d1b91968..0f0b120cbcc578b2fa7f650c37c9a28148dd834a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc index e50c98bbe844f6ec438c781b7d5d8e254e57d98f..25678491ce90cdf1ab9455f14e091bd5fcd67431 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index ab612cea1794c422c737a02e5bbe7a6728904c40..54f3a788097bebc6e293389cecd03d8ae68d8494 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/permute_1xHxW_fp16.cpp b/example/39_permute/permute_1xHxW_fp16.cpp index d7f9b80544a452ca7ea062d7d6135b92a2bf19ea..7336c3b631bcadd7d23a2d6b8e0b3d3a31787516 100644 --- a/example/39_permute/permute_1xHxW_fp16.cpp +++ b/example/39_permute/permute_1xHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_HxWx4_fp16.cpp b/example/39_permute/permute_HxWx4_fp16.cpp index 342aa134ec5570c84447b137e5be165c9a5f697c..6c24919ded6a0c65556e305455861ca137203b3e 100644 --- a/example/39_permute/permute_HxWx4_fp16.cpp +++ b/example/39_permute/permute_HxWx4_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_NxHxW_fp16.cpp b/example/39_permute/permute_NxHxW_fp16.cpp index b53975eb2c8632583f61afde50a03eba32c17c9b..3551d2a7c8decf31a48ca628d99a53639f4fa1fe 100644 --- a/example/39_permute/permute_NxHxW_fp16.cpp +++ b/example/39_permute/permute_NxHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/run_permute_bundle_example.inc b/example/39_permute/run_permute_bundle_example.inc index 70406d63f91b95f4d1cce025051a74a0ee3d114e..2c198729226243fb1fb97db6beefa1fd1f62f625 100644 --- a/example/39_permute/run_permute_bundle_example.inc +++ b/example/39_permute/run_permute_bundle_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/run_permute_element_example.inc b/example/39_permute/run_permute_element_example.inc index bc6235303039c7cb33a4057446137a9f18ddb186..35871344567ddf62a4fb3c2bd6a2e77cb94a1998 100644 --- a/example/39_permute/run_permute_element_example.inc +++ b/example/39_permute/run_permute_element_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index c3540d6ee6c90300e4dd9ebd476f8c27b678297b..b82a013b58f8e855615b494cf07bd1f03ce43556 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,16 +1,28 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) + set(target 1) + endif() +endforeach() # Conv perlayer quantization add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) # Conv perchannel quantization add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) # Conv + bias + relu perlayer quantization add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) # Conv + bias + relu perchannel quantization add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) + +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) \ No newline at end of file diff --git a/example/40_conv2d_fwd_quantization/common.hpp b/example/40_conv2d_fwd_quantization/common.hpp index 6ee14d750ef6c9f8308ff7b17b86626fc70ca2b8..266b09145c84c456aedb2182554ced7815457410 100644 --- a/example/40_conv2d_fwd_quantization/common.hpp +++ b/example/40_conv2d_fwd_quantization/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index df10e80396317bade24388e9a14a8f24289fb030..40b33852b5778f46e43ac9f527dfaec19c403697 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" @@ -76,6 +76,10 @@ using DeviceGroupedConvNDFwdInstance = 5, // CThreadTransferSrcDstVectorDim 4>; // CThreadTransferDstScalarPerVector -#include "run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc" +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" -int main() { run_conv2d_fwd_bias_relu_perchannel_quantization_example(); }; +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 18f9197b9cbf60f7de72dae77feefbee06f5ed4c..fc081ddc543a43b364f3538af8f3aa89edc2c164 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" @@ -74,6 +74,11 @@ using DeviceGroupedConvNDFwdInstance = 5, // CThreadTransferSrcDstVectorDim 4>; // CThreadTransferDstScalarPerVector -#include "run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc" +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" -int main() { run_conv2d_fwd_bias_relu_perlayer_quantization_example(); } +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c390f016a97c39d9c0e400f0e4ac826b41a18d69 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" + +int main() +{ + float scale_z_inv = 0.5f; + const auto out_element_op = OutElementOp{scale_z_inv, ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10b131a5271bf7347054b2bb428ddf0311fef349 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" + +int main() +{ + float scale_acc = 0.5f; + float scale_z_inv = 0.5f; + const auto out_element_op = OutElementOp{scale_z_inv, scale_acc, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index afff7f8b69a81fbab00d85a9bc903433552c26b7..e59d0d075959446b8164e50db7119e8bd1d1cf54 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" @@ -76,4 +76,8 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perchannel_quantization_example.inc" -int main() { run_conv2d_fwd_perchannel_quantization_example(); } +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_perchannel_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp index a38fe2a6c307b194103ae6560f20987a0b2660cb..aee5fe9e60b0a3dccb3c3209a297001c55bd8ff1 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" @@ -71,4 +71,9 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perlayer_quantization_example.inc" -int main() { run_conv2d_fwd_perlayer_quantization_example(); } +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index ba6990d9383655667077af5ec6bbab3dc07839cd..06c839e4e226e082e9754f12ce39cececbd36269 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" @@ -80,6 +80,10 @@ using DeviceGroupedConvNDFwdInstance = S<1, 64, 1, 4>, 8>; -#include "run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc" +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" -int main() { run_conv2d_fwd_bias_relu_perchannel_quantization_example(); }; +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 690d70e1127cc2dbc1950ba190b40f65e5f32d1d..7a9b42d39f971ab1c36476f4c0ead172cb6d2606 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" @@ -78,6 +78,11 @@ using DeviceGroupedConvNDFwdInstance = S<1, 64, 1, 4>, 8>; -#include "run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc" +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" -int main() { run_conv2d_fwd_bias_relu_perlayer_quantization_example(); } +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index dd755ff0650b68a473c905ca2bcadc964fe2c686..3495636297d5d83fb3ccc7d236e6135de9b715fd 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" @@ -80,4 +80,8 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perchannel_quantization_example.inc" -int main() { run_conv2d_fwd_perchannel_quantization_example(); } +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_perchannel_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index 48617e4775e0d211a9ba161b13e61a59684ea064..2611337254212b326f6eef540c6a4eca5cd19261 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" @@ -75,4 +75,9 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perlayer_quantization_example.inc" -int main() { run_conv2d_fwd_perlayer_quantization_example(); } +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc similarity index 96% rename from example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc rename to example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 822a1ed8b5fca79e40f959ab723053a2bac0e7bb..e5b924ad5114a6e31a1ba2875118085329d5e8ce 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once template (conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc similarity index 96% rename from example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc rename to example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 00cbaa09ee89e637f6f6166d1229a337fe905bc1..9f3a769dcf0979d407169610f5cdbe15dbccb26e 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_bias_relu_perlayer_quantization_example() +int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op) { bool do_verification = true; bool time_kernel = true; @@ -177,12 +177,11 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example() const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using BiasLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 2e0623028d19ae510a80c9efc1d2c6cb7ee5d43f..9b08fc690d8298b87ce3ebf89f462f3739734815 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_perchannel_quantization_example() +int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op) { bool do_verification = true; bool time_kernel = true; @@ -179,12 +179,11 @@ int run_conv2d_fwd_perchannel_quantization_example() const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{ActivationOp{}}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index aeccb30cf227616210c25fa1cf8ae09c17f1b0f3..267c737e00fc5046e06551ca3e366d01dc78d230 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_perlayer_quantization_example() +int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op) { bool do_verification = true; bool time_kernel = false; @@ -161,11 +161,10 @@ int run_conv2d_fwd_perlayer_quantization_example() const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 9cb30f61760a47a2f72c98e31de2b522b97543ca..ae251e88d205a04968eab430b52c004a1136d853 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,8 +1,18 @@ -add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) -add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) -add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx908 gfx90a) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() -if(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) +endif() diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp index 2aea08c40004fac22ea0e45521ad09fd806751d8..e37d41369599819bdf1b9836c1a0587be2e31a0f 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp index b7f80e76d6ce5a690beef4f48361e0a65402cb47..496e676a402f4d12a325c2afef56ec422cb4d721 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp index 15e460948ef528b697a4e3bad7d40ca332f85c15..35d50721dcea27b2bfb42eea6e39da08709dfc7d 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 2cc4c07c0d89fc83b233b50967cd595dc0e53ac7..80f6e9ae05712b1df4b05b1bb75fbff62d9ca70a 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp index 40ff0f69cc14cd644361f3391eea8ed52380adc0..3ade6c811ae5583bbb8a2a3b277c6f4e7f265a87 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc index a2c97f4d421f9382aa279b257a49299c4016e2c3..0722d497d8df0ebf0acf0cf798215bebcdee6e3a 100644 --- a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc +++ b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt index c3b7b825920314aba3fa15e401da26860c23a3c3..e8c306ac582f56607c8bfa309c4a3f16cd64be50 100644 --- a/example/42_groupnorm/CMakeLists.txt +++ b/example/42_groupnorm/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp) +add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) +add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) +add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c8f91eb53b5b943bfbdd264b842cf7762a176b82 --- /dev/null +++ b/example/42_groupnorm/common.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc107b63dcd87735ba40ebbe7633c61993406166 --- /dev/null +++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; + +struct YElementOp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + T a; + + ck::tensor_operation::element_wise::Sigmoid{}(a, x); + + y = x * a; + }; +}; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector + +#include "run_groupnorm_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..057b240a63fc38217c095b6f3a058e1d4db6ea27 --- /dev/null +++ b/example/42_groupnorm/groupnorm_splitk_fp16.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // OutScalarPerVector + +#include "run_groupnorm_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..363f22ed4c0015c9a96aa98a8b6b5302978e3d25 --- /dev/null +++ b/example/42_groupnorm/groupnorm_swish_fp16.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector + +#include "run_groupnorm_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/run_groupnorm_example.inc similarity index 53% rename from example/42_groupnorm/groupnorm_sigmoid_fp16.cpp rename to example/42_groupnorm/run_groupnorm_example.inc index 35c7c054e05688b8a09817e738bcca503ff0041d..16065c8d46a42433416e21e9607dc3fcb708d817 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp +++ b/example/42_groupnorm/run_groupnorm_example.inc @@ -1,80 +1,15 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" - -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_common_util.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; - -struct YElementOp -{ - template - __host__ __device__ void operator()(T& y, const T& x) const - { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, - "Data type is not supported by this operation!"); - - T a; +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - ck::tensor_operation::element_wise::Sigmoid{}(a, x); +#pragma once - y = x * a; - }; -}; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -int main(int argc, char* argv[]) +int run_groupnorm_example(int argc, char* argv[]) { - ck::index_t N = 2; - ck::index_t H = 32; - ck::index_t W = 32; - ck::index_t G = 32; - ck::index_t C = 30; + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; if(argc == 1) { @@ -138,6 +73,10 @@ int main(int argc, char* argv[]) return 1; }; + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = device_instance.MakeInvokerPointer(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index 7ac4b68272e2b0a1a5cbc09fa3734e190d6897da..b6d9b29a5b757f024ca09c52d9b8ce3f5430e555 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp index 764e55ef558bd977fca8f25f16b2bba4c111114b..60a0e01fe3e3db9f490304bcd4bbe322a04d3c0b 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 7d6ff12eeafe8de6f45e93567959d86d748c783c..76361f87a5b58071b149cab505f7c0485abfadab 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/46_gemm_add_multiply/common.hpp b/example/46_gemm_add_multiply/common.hpp index 3ba78dfe47ba4cac2f647bf5cb609561e6989275..2c656cf44159a594f0b8ba3b40b7d8d936d57cb2 100644 --- a/example/46_gemm_add_multiply/common.hpp +++ b/example/46_gemm_add_multiply/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp index 28c3939fa611d536600bdce9c7609705e0373da1..58a399f226c1b59073dd000e382e673badb5e1fb 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index d5aa41f1b6d5a97c5c1b614e2765a86fde7343e1..56417b101d9528674535ab93f6fdc7ab36d47c3d 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index d1b3dd4be22c07f2b8fc87947328c849ff10650b..14432f6e23d4adbf09e0df6c805a68f1a5571f69 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1 +1,8 @@ -add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp index 30c98e534a99fa05ff6ebec044c5b70df7419254..a90a6340a431a55c271dca3d3d0d1771382218f5 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -121,7 +121,8 @@ using DeviceOpInstance = 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - MaskingSpec>; // MaskingSpecialization + MaskingSpec, // MaskingSpecialization + 1>; // Ref Gemm0: fp16 in, fp32 out using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" + +template +bool pool3d_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Z, + ck::index_t Y, + ck::index_t X, + ck::index_t Di, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_d, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t in_left_pad_d, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_d, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + IndexDataType, // IndexDataType + ComputeDataType, // ComputeDataType + ReduceOpId, + OutputIndex, + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 4>; // InSrcOutDstVectorSize + + const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Z, Y, X}; + const std::vector window_strides{ + window_stride_d, window_stride_h, window_stride_w}; + const std::vector input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + auto layout) { + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } + }; + + Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{})); + Tensor out_n_c_do_ho_wo_host( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_do_ho_wo_host( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_n_c_do_ho_wo_device( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_do_ho_wo_device( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_di_hi_wi: " << in_n_c_di_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_do_ho_wo: " << out_n_c_do_ho_wo_host.mDesc << std::endl; + + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_di_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + {N, C, Di, Hi, Wi}, + {Z, Y, X}, + {N, C, Do, Ho, Wo}, + {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}, + {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, + {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, + window_strides, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "Perf: " << ave_time << std::endl; + + bool pass = true; + + if(do_verification) + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<5, + 3, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ReduceOpId, + PropagateNan, + OutputIndex>; + + auto ref_pooling = ReferencePoolingFwdInstance{}; + auto ref_pooling_invoker = ref_pooling.MakeInvoker(); + auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_di_hi_wi, + out_n_c_do_ho_wo_host, + out_indices_n_c_do_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + ref_pooling_invoker.Run(ref_pooling_argument); + + out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device, + out_indices_n_c_do_ho_wo_host); + }; + } + + return (pass); +}; diff --git a/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9afb51201deebe47b71932ab8fe1b9ffb50edcd6 --- /dev/null +++ b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "pool3d_fwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + bool pass = pool3d_test(do_verification, + time_kernel, + N, + C, + Z, + Y, + X, + Di, + Hi, + Wi, + window_stride_d, + window_stride_h, + window_stride_w, + in_left_pad_d, + in_left_pad_h, + in_left_pad_w, + in_right_pad_d, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/CMakeLists.txt b/example/49_maxpool2d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b29cf9ccbce89d48bcff009d46b94e8c8e036741 --- /dev/null +++ b/example/49_maxpool2d_bwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) +add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) +add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08a8009c6c35cc3f9c7881efb28daa571c96b5cd --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::bhalf_t; +using DOutDataType = ck::bhalf_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..045793cc2ae1ca9684ac5142821782d2c03cdb44 --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp" + +template +bool maxpool_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + IndexDataType, // IndexDataType + ComputeDataType, // ComputeDataType + ck::ReduceTensorOp::MAX, + true, // OutputIndex + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 1>; // InSrcOutDstVectorSize + + using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: + DeviceIndexPoolBwdImpl; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Y, X}; + const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + using namespace ck::literals; + // reference need Tensor with NCHW order + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + }; + + // in + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); + + // out + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // indices + Tensor indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // dout + Tensor dout_n_c_ho_wo(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // din + Tensor din_n_c_hi_wi_host(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor din_n_c_hi_wi_device(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "indices_n_c_ho_wo: " << indices_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "dout_n_c_ho_wo: " << dout_n_c_ho_wo.mDesc << std::endl; + std::cout << "din_n_c_hi_wi: " << din_n_c_hi_wi_host.mDesc << std::endl; + + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + dout_n_c_ho_wo.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * + indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout_n_c_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data()); + + auto pool_fwd = DevicePoolFwdInstance{}; + auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer(); + auto pool_fwd_argument_ptr = pool_fwd.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + {N, C, Hi, Wi}, + window_spatial_lengths, + {N, C, Ho, Wo}, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + window_strides, + input_left_pads, + input_right_pads, + {2, 3}); + + if(!pool_fwd.IsSupportedArgument(pool_fwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_fwd with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time_fwd = + pool_fwd_invoker_ptr->Run(pool_fwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + auto pool_bwd = DeviceMaxPoolBwdInstance{}; + auto pool_bwd_invoker_ptr = pool_bwd.MakeInvokerPointer(); + auto pool_bwd_argument_ptr = pool_bwd.MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), + window_spatial_lengths, + window_strides); + + if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_bwd with the specified compilation parameters does " + "not support this problem"); + } + + size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get()); + DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz); + pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(), + pool_bwd_workspace_device_buf.GetDeviceBuffer()); + + float ave_time_bwd = + pool_bwd_invoker_ptr->Run(pool_bwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl; + std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl; + + bool pass = true; + + if(do_verification) + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<4, + 2, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ck::ReduceTensorOp::MAX, + PropagateNan, + true>; + + auto ref_pooling_fwd = ReferencePoolingFwdInstance{}; + auto ref_pooling_fwd_invoker = ref_pooling_fwd.MakeInvoker(); + auto ref_pooling_fwd_argument = ref_pooling_fwd.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); + + using ReferencePoolingBwdInstance = + ck::tensor_operation::host::ReferenceMaxPoolBwd; + + auto ref_pooling_bwd = ReferencePoolingBwdInstance{}; + auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker(); + auto ref_pooling_bwd_argument = ref_pooling_bwd.MakeArgument( + dout_n_c_ho_wo, indices_n_c_ho_wo_host, din_n_c_hi_wi_host, PassThrough{}); + + ref_pooling_bwd_invoker.Run(ref_pooling_bwd_argument); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data()); + din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host); + } + + return (pass); +}; diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6b0b2c86006ddb30788a06ffac7fdb177f3abca --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::half_t; +using DOutDataType = ck::half_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c79b84c489d637e3c3ec9b52fd726537af3f776b --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = float; +using DOutDataType = float; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b0020ebcf65a87ad39b004c149385cdd7de89ae --- /dev/null +++ b/example/50_put_element/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/50_put_element/put_element_fp16.cpp b/example/50_put_element/put_element_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d4b6831bc4713bac25358998df3e9abfb9785e80 --- /dev/null +++ b/example/50_put_element/put_element_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using XDataType = ck::half_t; +using YDataType = ck::half_t; +using IndexDataType = int32_t; + +using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceInstance = + ck::tensor_operation::device::DevicePutElementImpl; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + int N = 1024; + + Tensor x(HostTensorDescriptor{N, 1}); + Tensor indices(HostTensorDescriptor{N, 1}); + Tensor y(HostTensorDescriptor{N, 1}); + + x.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + for(int i = 0; i < N; ++i) + indices(i) = i; + + DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize()); + + x_device_buf.ToDevice(x.mData.data()); + indices_device_buf.ToDevice(indices.mData.data()); + + auto put_instance = DeviceInstance{}; + auto put_invoker_ptr = put_instance.MakeInvokerPointer(); + auto put_argument_ptr = put_instance.MakeArgumentPointer( + static_cast(x_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(y_device_buf.GetDeviceBuffer()), + N, + N, + YElementwiseOp{}); + + if(!put_instance.IsSupportedArgument(put_argument_ptr.get())) + { + throw std::runtime_error("argument is not supported!"); + } + + float ave_time = + put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + Tensor y_host(HostTensorDescriptor{N, 1}); + + for(int i = 0; i < N; ++i) + { + IndexDataType idx = indices(i); + y_host(idx) = x(i); + } + + y_device_buf.FromDevice(y.mData.data()); + pass = ck::utils::check_err(y, y_host); + } + + return (pass ? 0 : 1); +} diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5009dec5e96f4521eabaad5351f3664be97c829a..161c3261ad39272f0040c8c289167e3890b10ddb 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,20 +31,21 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) // for GPU code + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code -#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif // FMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx1030__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -53,14 +54,19 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if defined(__gfx90a__) +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define CK_USE_AMD_MFMA_GFX940 +#endif + // WMMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_WMMA @@ -80,13 +86,15 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if defined(__gfx90a__) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -165,9 +173,21 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 + +// workaround: Grouped Conv2d_bwd_data fails for already implemented instance +#define CK_WORKAROUND_SWDEV_3318619 0 + // flag to enable (1) or disable (0) the debugging output in some kernels #define DEBUG_LOG 0 +// denorm test fix, required to work around dissue +#ifndef CK_WORKAROUND_DENORM_FIX +#define CK_WORKAROUND_DENORM_FIX 0 +#elif +// enable only on MI200 +#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#endif // CK_WORKAROUND_DENORM_FIX + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index e2cbdb733272d37aa5dbc7a746e86911d6b8644f..bd02d5d88a2b0a6f882ac58bfada67649e210bc0 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index d3dc8eaf1eb8b87207256ea4a521e23d5a49ca9c..af7bebd9d6afbf59f1852156b1708ed843d9a9f6 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index ac8719592db8b48996bfbc65dc656d6d96bde545..55734bab2e469a12cd02c417cbb25d9c8a49727c 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 24f2121674c7d371a7d13c0b1c83e7f0a5e059be..58740b43517b9e75bdab59d2e8134d5c6dd97854 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/stream_utility.hpp b/include/ck/host_utility/stream_utility.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9ab49489bbe59b5152ad61c74622be48c33c28a7 --- /dev/null +++ b/include/ck/host_utility/stream_utility.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config) +{ + constexpr int MAX_MASK_DWORDS = 64; + + // assume at most 64*32 = 2048 CUs + uint32_t cuMask[MAX_MASK_DWORDS]; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + cuMask[i] = 0; + + auto countSetBits = [](uint32_t dword) { + int count = 0; + + while(dword != 0) + { + if(dword & 0x1) + count++; + + dword = dword >> 1; + }; + + return (count); + }; + + hip_check_error(hipExtStreamGetCUMask(stream_config.stream_id_, MAX_MASK_DWORDS, &cuMask[0])); + + int ret = 0; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + ret += countSetBits(cuMask[i]); + + return (ret); +}; diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index db8e48df6d4b18e9c8a60b8f3d4382d1c6e649f8..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,275 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// Number of GEMMs = YTilde * XTilde -// GemmM = C -// GemmN = N * HTildeSlice * WTildeSlice -// GemmK = K * YDotSlice * XDotSlice -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - Number, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - constexpr auto IYTilde = Number{}; - constexpr auto IXTilde = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - - const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - - // only work on HTilde and WTilde that contribute to non-padding area of input tensor - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - - const auto IHTildeSliceEnd = - math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildeSliceEnd = - math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde); - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilde), - make_freeze_transform(IXTilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - -#if 1 - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // output tensor - // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - -#if 1 - const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IXTilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_pass_through_transform(C), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 5391b595b5c1e9b5b7395d1a676efcd87b9f2e6e..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,355 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: out -// B: wei -// C: in -// Number of GEMMs = YTilde * XTilde -// GemmM = N * HTildeSlice * WTildeSlice -// GemmN = C -// GemmK = K * YDotSlice * XDotSlice -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - IYTilde i_ytilde, - IXTilde i_xtilde, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - - const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - - // only work on HTilde and WTilde that contribute to non-padding area of input tensor - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - - const auto IHTildeSliceEnd = - math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildeSliceEnd = - math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // A: output tensor - // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - -#if 1 - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // B: weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - -#if 1 - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // C: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -// A: out -// B: wei -// C: in -// Number of GEMMs = 1 -// GemmM = N * Ho * Wo -// GemmN = C -// GemmK = K -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // A: output tensor - const auto out_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); - - // B: weight tensor - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: input tensor - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp deleted file mode 100644 index bb1dc239f4d7e6a5abade2dda4a17a98ab7d1a47..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmK = N * Ho * Wo -// GemmN = C * Y * X -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = C * Y * X; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index ca530934e49dcccea388fd8501e5bf046f1c94af..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmK = N * Ho * Wo -// GemmN = C * Y * X -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = C * Y * X; - const auto GemmK = N * Ho * Wo; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index e960f90c4bb1974ddfaaf9faf6a12d75d1db2f4d..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: output tensor - const auto out_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 052bab423db04740c88721026a3cceb3fa77c292..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmK = N * Ho * Wo; - const auto GemmK0 = GemmK / GemmK1; - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B: output tensor - const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index c301a9e0c67b11c42fbe2a0a99c4f764049a5222..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,147 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: out -// B: in -// C: wei -// GemmM = K -// GemmN = Y * X * C -// GemmKTotal = N * Ho * Wo -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp index 41267536551ea298e111f6d93b8e6f3f8f9ed475..6b118e972e67897f2c9e0cad3a8f959760f82f6c 100644 --- a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 381f9ac9d6f5a9b000eb44839c7198e4dc9f5d9b..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,260 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index ebfaabb03ebdda41d086e4b80c5856156810992a..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple( - wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); -} - -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple( - wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 6e576d69f5f4b5a9a60a05f84edb43fec990e91a..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 13e1bf251abfdc3716bdd17f331f82fd7bdfad41..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 088d14b2ee472bc049385ad27e45d02df1e3b6c9..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,134 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = N * Ho * Wo; - const auto GemmN = K; - const auto GemmK = Y * X * C; - const auto GemmK0 = GemmK / GemmK1; - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B: weight tensor - const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index a6785d56df7e70fdb4882f0c431b886d8c87caac..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM0 = 1 -// GemmM1 = K -// GemmN0 = N0 -// GemmN1 = (N / N0) * Ho * Wo -// GemmK0 = (C / C0) * Y * X -// GemmK1 = C0 -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const N0Type& N0, - const C0Type& C0) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto N1 = N / N0; - const auto C1 = C / C0; - - // weight tensor - const auto wei_gk0_gm0_gm1_gk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_unmerge_transform(make_tuple(I1, K)), - make_unmerge_transform(make_tuple(C0, C1 * Y * X))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_unmerge_transform(make_tuple(C0, C1)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); - - const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor( - in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C1, Y, X)), - make_pass_through_transform(N0), - make_merge_transform(make_tuple(N1, Ho, Wo)), - make_pass_through_transform(C0)), - make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // output tensor - const auto out_n_k_howo_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)); - - const auto out_n0_n1_1_k_howo_grid_desc = - transform_tensor_descriptor(out_n_k_howo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_unmerge_transform(make_tuple(I1, K)), - make_pass_through_transform(Ho * Wo)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor( - out_n0_n1_1_k_howo_grid_desc, - make_tuple(make_pass_through_transform(I1), - make_pass_through_transform(K), - make_pass_through_transform(N0), - make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), - make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - return make_tuple( - wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 70ca34555a01436c79ba244ac03572bb4e9520b4..505a602b240428bcd1f0f81018fef0c4716c80b7 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index fee679f91060aca623ab498362de78fd8118fee2..d719ef9760d79297600d7524167eba78cd137831 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index 0c9ea2ff2a0d73b793008a954eaf7293b33ade08..2dfcad8e042e548d15b4bd963fe615e39c04eff1 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 4e4d7593e9083f321f1cabe65faaee3c14259666..6854226dd4da975a8df4e5d516232f69476f515f 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 044a90370095eb53b94b5c2fba81abdbeae82c00..af0a8a34d0e48bde494dacfff5bbb17eda5eb479 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index d42e0a6ff08f60eb1cf6e0f241f93e528bf3514b..3ffac32469a8d97d678bc6b5fe62a7cc5e0b24ab 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index f07d5b1733d9cc96dfac9f18cbbe30510760cabc..f1df2eedd466c81e4b7938c4808075869e6309fd 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index 461aae72cf7b1879c90e473adb3814e1c4875b52..f3ac041bf9b8dc03802d170fbd7bf08ce6ab9cb1 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 17c9100b9fd76418d562c5c175ad24062d8c3415..9a326092d2e0fd8392ec42a8c0a82b4167076373 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 8b1b7be11ef7b3cd19f11bd400b8a96146921bf7..b3caa3214a80893ad1acbf7b5f385bcbc513670b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 33120bd86ff01da47f7b80c593bc966785cb3711..b0143366c1de5f8e6ec52cceefd3c20e4e77cbb7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp index f45655721fe4de11b50b91e0dc5d22790ff73bea..0d092da5168c5d5629eecc2aac7cbdd53e210277 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index d75f37d7b3915408b05c1bc77b478ab3e41dc8b0..5ec964bd3fdbda9319d1c6eab0ccb207775a6137 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 5328dfde9bc09039af03892467d94545002c56fd..d5a64d7aa6fe6c96f384d3f00af642231b6b2c53 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index aa814ab00939d0c36866af01b6cf2057dcbd5121..8ae1ba3f34c10d8359b470ed1172f1b70a7fa8b5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp index 7e62a822a8f042d238905295b7f88c1d9bdd88f8..82bcff694757e2a4e1c0c8e9fee02c6bba51ca3e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 03e4d42d3a1f9eb1b180c95368b905e619e67110..d8da134a3415a6976b27d0a4fdd7f13798d0245b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp index 316508651e4bd4dad2e47edf020a4145176fa753..a3813ea248608252bbbfe49e52c96e66534bf290 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index 2163ad32383d4f940ab4cf3b041ab2d5c61456cc..6c13513cfb01a2be9b6a8acb161a843ab4c09b7d 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index 04ad75bd7de4a9657c2aaaf21b40205c3ff9c7d8..c8690e5f68765622222ac3a8cff7c59e6fc06e91 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 5c47a49b38b6a30eebc9190cc338d4dbc0bc8524..905a59f56e3b42bfb4b2bdf8652a3e81b2872076 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index aa33fc083f15cfd8904c8fa086a866dbc7817e7c..17110c8358dbaf7e602e7f9822968e08f80e14ef 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index eb5f589a4ada976c9be4a5001fb6fc288c7e9a43..9a5317dd126a818bc88043408427f0f105cb7f6e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp index 3bd7806389b59542b191e61f5369aebf7300fc6a..993d90e356def70ba54e357889277188b4f21971 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp index a4a29f5d5edc3b840fc7489d5f706231990cad70..f3263c7216a61528bed5408d0a950f8ae40983e1 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 20b2a152b9d276b0bd0f6e1aa5798b4817b013a9..01bb806789c3d9e50f706d7c88982b4a6624ff57 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 953ff1e06ed07b24968c7f5c0842161ac66643ed..adfa1689c66509c8c194985365356740d4c90473 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 5946daf21ec169612c3207282007e05671882f09..198169011107fb0f236d4657c399ff0534ce2c98 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp index 9fcd893c7a8c38d31f480b13eb8e7bfc97ec7fef..ee7af0117d17acc561e5943ac59611e188b6ba4c 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index e755913280f73e325daf1352324bab8d3df8a3a9..6cc2c7bb2f6c176f2d84fdda4be2140db5564360 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp index af681127f30404f182475039add4443e46213398..91b4b6b91b6b6be5d08d88e5ca68c9b62fb02aea 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp index 116e62c00907e0f72031552f46b8a3ea975a18f5..f18dc3290600e63edb869d93dcb1a903206e32ab 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp index eacc5976d3ea3885a8294f1ee74a249f0a3d52a7..8234e29486b1af8ae0dd43e58d4be9b63b5ed34f 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index c1f85e575ce4b2e0d70c05010bfca6e9ed9b84a8..09259224e75cc67e092eeeaab1af793b050cc1f0 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp index bde71806daa0543c41813f2afb0dd34f16bfb76e..be8105c96726dfb3d8931b376a668c3296040706 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp index d39f3b7cbcfc9ccaccca5f755be747cc4896dc1b..2c0da692570b455da42540bb72793b9e629cb823 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp index aa93dd9c19d77fbee7ccdafd1e41e4c4f64735cd..e3962e177ee824981c35fbe2ba04664e4124315b 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp index 8a00fd9db3327a7c536b1ada0ca090bfa847a256..69103b6f44297ed65ed4233e3d458ac6927bcd16 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index aedae53800b03cb78e6af5b15c8c6dfcd9792eea..8484212118c796f68d1bf066eab7b8ddbb231797 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index dbc525c099bca293d106ac6d3c37ac72b2d749cc..118ade8978643f10dcfafbefb552fe485d6697e4 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp index 82054a3c9423f819d17a003b8c22844f0a2273fe..eb1b85ec822aa96822b5a288d55a5d3a0898f1e9 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index 4b9881088dde07a62e92c5f77a19edc9c4b7670f..4dc11dbefd73619c39febe31a91a953b25049f05 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp index 5a627deeb2221f5e271532d45bc8a544c8cceeba..7d3845666cf2289f34b186a5aaa8ee616812f30a 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp index cc139303c929ae999d182ad09e9a4ef33f5209e1..3a49ac632e7de745e78ad5670dc57f425ab9eb33 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp index f9f913a7c1f1a6cdfdb41f075f2609e8e2407d07..db0e4bd83f46ad001492876ff8a712895a873c22 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp index 9491a92247c71e3938d0f7d81cc0d82741b693de..c56a947ec9e11a0a5dce2cd39ce86d8e6a3efb8c 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index c0af6f80faf606a22678e516be994a75c1d56eca..adf909821dbbdde12526513a7b6218e8ebbd219d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp index 4c2161eaed5b9c7d7913686f44c360c044ca3b07..a7f42c3b35e237cb4ad9edcbcd137f453af59051 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 9113bb7b7454443cce4871bf263433d4f907bfba..a44356dc2405bfe6ed448a8193d55e86bd29415c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp index a67a09b87416bf11cc285ab1b1d9dc684b17103f..0258858fe50e2d433e5ddfe007e8d06227b47c0c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp index f4881e32f620be9c31649085813a257e9e84a598..539e83f7cb8a5dfe99b2ecb3da491f9661330265 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index fcc088ca43d1b8c6224c0015e3eb7434038af4b2..eaa7671c6424b4ecf5ad87c4e790a9c478645b9b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index c701bff57f8eb7db155e9abdfd3ab7210e7eeffd..6407aa7e09b81a4dc09dcb1bdca5701aa86d9e14 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp index d985d0f92e3d89e6942f7bdd42913dbe6ffc643d..d0de101c49c563ae40086f91cfb1232b9c11e81d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -47,7 +47,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -416,7 +417,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm #include @@ -31,7 +35,7 @@ struct DeviceGroupedGemm : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); - static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); virtual std::unique_ptr MakeArgumentPointer(std::vector& p_a, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp index b066a4458518b378313225e24817253922962d29..fae65097407cdb26df206308442874c0fc326fe6 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 70f3d0277dde8478d2a83958f67df47ead59c613..ab6d2716c0445262f0cec5d53a5dafc3ab3c9555 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -43,7 +43,8 @@ __global__ void const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -610,10 +611,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle some_has_main_k_block_loop |= y; } - hipGetErrorString(hipMemcpy(arg.p_workspace_, - arg.group_kernel_args_.data(), - arg.group_kernel_args_.size() * sizeof(GroupKernelArg), - hipMemcpyHostToDevice)); + hipGetErrorString( + hipMemcpyWithStream(arg.p_workspace_, + arg.group_kernel_args_.data(), + arg.group_kernel_args_.size() * sizeof(GroupKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); float ave_time = 0; @@ -678,7 +681,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..06d180d30fb0f3c8e3b348dd1668ff582de87493 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp @@ -0,0 +1,39 @@ +#pragma once +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm +{ + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf81ed9f5b1c2e6b46208ffede15370a3ae78088 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// For pooling which used indexable operation, such as MaxPool, MinPool...etc +template +struct DeviceIndexPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp index ee4b53e2fcc43f44aba29c3409963a1b94834cbc..f68022ca04de757fee1240ced4b504601b183d91 100644 --- a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index 03601ce8312ccea06d60236779aa78a9ea306e01..1f178f9fcb65ffdd7ab09146d94131c1d5c993f2 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_permute.hpp b/include/ck/tensor_operation/gpu/device/device_permute.hpp index 9daa2be37338c8309bd2131f263d324fe15b7ab2..c994cf02c6aab5b85fcd2fcc29f2716de28b861d 100644 --- a/include/ck/tensor_operation/gpu/device/device_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp deleted file mode 100644 index 3b376c6f73f066729fa118013970682ca31b340e..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/utility/reduction_enums.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DevicePool2dFwd : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* in_dev, - void* out_dev, - void* out_indices_dev, - ck::index_t N, - ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DevicePool2dFwdPtr = std::unique_ptr>; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b227fdfbe62798f32c98561f5235ef0c5263de4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePoolFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_lengths, + std::vector window_lengths, + std::vector output_lengths, + std::vector input_stride, + std::vector output_stride, + std::vector indices_stride, + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_put_element.hpp b/include/ck/tensor_operation/gpu/device/device_put_element.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9186827491c04715926fd6a981b1f7f1e626f55e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_put_element.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElement : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t output_length, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp index c9209f2d7d681b44c8a27dd06626eb8a4b998a39..c2721b18455c6a46e5973fad32767ae43247f59b 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/include/ck/tensor_operation/gpu/device/device_softmax.hpp index 94f788e5177cd78d0ea917b0869c576ba8fe7bfb..1902fd09eec1ed01954b6cad47b0e97ff898eb40 100644 --- a/include/ck/tensor_operation/gpu/device/device_softmax.hpp +++ b/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,7 +18,8 @@ template + index_t Rank, + index_t NumReduceDim> struct DeviceSoftmax : public BaseOperator { // @@ -49,8 +50,6 @@ struct DeviceSoftmax : public BaseOperator AccElementwiseOp acc_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; - virtual index_t GetRank() const = 0; - virtual index_t GetNumReduceDim() const = 0; }; template -using DeviceSoftmaxPtr = std::unique_ptr< - DeviceSoftmax>; + index_t Rank, + index_t NumReduceDim> +using DeviceSoftmaxPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp index f59e6093e2ae024ccbf8827082c755eba8cee88f..eeccd977ccbf4440a349845826d45d5a3274ad37 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 8eab1cdee565b4334ede5ca9bb7544ea37ff498d..f849ac799d4611757e7a7c4d2949f602f5eb0e44 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -56,7 +56,8 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -938,7 +939,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index fc913e9ba03611dfbabea020bd4054c0b953726d..0bb45b18c3e19b2ec5f9347c1e811d8734ee45a9 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 493822aeb27215406a8f7fad56021aca8e4246a7..4d599e8017991b4f1b1905de9acd23a219ac67f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 2237ad944cd61775267851a9a2629efb115f286f..46e71240c1f5a655e3f6c2b95b1dc9d10214b6f1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -56,7 +56,8 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -839,7 +840,9 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 01f5e17d91436747d0b3bf052061741b34e0b46a..fc080df5f5d096a7d5b2166da28786bb977127e7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -74,7 +74,8 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 3b87e56337f95e2463a8211c961504fceb764d8b..7012584aabfb50f79408ccb663f4cd427d1ccfe2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,7 +60,8 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -588,7 +589,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + */ + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +// TODO: Enable for gfx90a after complier fix +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD + +{ + using DeviceOp = DeviceBatchedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + K_(K), + Batch_(Batch), + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = + DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + index_t K_; + + // Batch + index_t Batch_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultipleD_Dl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.Batch_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = + kernel_gemm_dl_multiple_d; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // TODO: Enable for gfx90a after complier fix + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" || + ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx940" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + bool pass = true; + pass = pass && arg.K_ % K1 == 0; + + pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.e_grid_desc_m_n_); + + return pass; + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 4fc8e69d2c68bd8071c33d896af046604f646e7b..d259ebffae47f98f0e482b9d8b0e480d4a4079ed 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -68,7 +68,8 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -804,7 +805,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 080e26ea89277c5d71362291381435fe1682b3b7..1d52416d4d8d875ac7e6dd9dcbf3ef8f98932b1c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,7 +59,8 @@ __global__ void const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 037867d5f45f2e33143e77a2a74690f34af9dc74..73b12bfb4747f2b3971c9a30300c0e4f5745291e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,7 +67,8 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -196,7 +197,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + LoopScheduler LoopSched = LoopScheduler::Default> struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle : public DeviceBatchedGemmSoftmaxGemmPermute; + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + D0sTransferSrcScalarPerVector>; // Argument // FIXME: constness @@ -529,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle using D0DataType = remove_cvref_t>; // D0 pointer p_d0s_grid_(i) = static_cast(p_acc0_biases[i]); + // for check + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]); + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); }); if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, @@ -607,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle std::vector b_nz_kz_strides_; std::vector b1_nz_kz_strides_; std::vector c_mz_gemm1nz_strides_; + std::array, NumD0Tensor> d0s_nl_ns_lengths_strides_; index_t batch_count_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; @@ -714,7 +723,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); #endif - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } @@ -770,6 +781,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle { return false; } + for(int i = 0; i < NumD0Tensor; i++) + { + if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 && + arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0) + { + return false; + } + if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1) + { + return false; + } + } return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 1f21f2d7122ca7b31553c915f6698f249c3f0aaf..2405640ca94bf0aae26d99b327e4d228313cafd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -62,7 +62,8 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -612,7 +613,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 48a224456052fd98184601db84428eb2271fa695..b6a0567fee5277bab30a74a9160591144bae60ae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,74 +45,46 @@ namespace device { * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). * */ -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_ctile_map) + kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_ctile_map; + ignore = karg; #endif } @@ -170,93 +142,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - const auto a_grid_desc_k0_mp_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_mp_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto b_grid_desc_k0_np_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_np_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto c_grid_desc_mp_np = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return c_grid_desc_mp_np; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -288,121 +173,82 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - NumGemmKPrefetchStage, - LoopSched, - PipelineVer>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpecialization::MNPadding, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using Problem = typename GridwiseGemm::Problem; // Argument - struct Argument : public BaseArgument + struct Argument : public Problem, public BaseArgument { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, - index_t Batch, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - Batch_(Batch), - a_grid_desc_k0_m_k1_{ - DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, - b_grid_desc_k0_n_k1_{ - DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, - c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, - block_2_ctile_map_{ - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} + index_t Batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC} { - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - index_t Batch_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; }; // Invoker @@ -410,107 +256,39 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + gdx *= karg.Batch; float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -530,17 +308,14 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm(static_cast(p_a), static_cast(p_b), @@ -618,12 +385,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto MPad = M - MRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_right_pad_transform(MRaw, MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - // not pad M or K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto NPad = N - NRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(NRaw, NPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // not pad N or K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - } - - static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideC)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ALayout, + BLayout, + CLayout, ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -396,10 +130,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, InMemoryDataOperationEnum::Set, - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, - CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -433,108 +165,82 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + // Argument - struct Argument : public BaseArgument + struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem { - Argument(const ADataType* p_a_grid_real, - const ADataType* p_a_grid_imag, - const BDataType* p_b_grid_real, - const BDataType* p_b_grid_imag, - CDataType* p_c_grid_real, - CDataType* p_c_grid_imag, + using Problem = typename GridwiseGemm::Problem; + + Argument(const ADataType* p_a_grid_real_, + const ADataType* p_a_grid_imag_, + const BDataType* p_b_grid_real_, + const BDataType* p_b_grid_imag_, + CDataType* p_c_grid_real_, + CDataType* p_c_grid_imag_, CDataType* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_real_{p_a_grid_real}, - p_a_grid_imag_{p_a_grid_imag}, - p_b_grid_real_{p_b_grid_real}, - p_b_grid_imag_{p_b_grid_imag}, - p_c_grid_real_{p_c_grid_real}, - p_c_grid_imag_{p_c_grid_imag}, - p_aux_grid_{p_workspace}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid_real{p_a_grid_real_}, + p_a_grid_imag{p_a_grid_imag_}, + p_b_grid_real{p_b_grid_real_}, + p_b_grid_imag{p_b_grid_imag_}, + p_c_grid_real{p_c_grid_real_}, + p_c_grid_imag{p_c_grid_imag_}, + p_aux_grid{p_workspace} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } - - const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_)); if constexpr(is_same::value) { - c_grid_desc_m_ = - DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + c_grid_desc_m = + DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize); } else if constexpr(is_same::value) { - c_grid_desc_m_ = - DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + c_grid_desc_m = + DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize); } - p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); + p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_); } // private: - const ADataType* p_a_grid_real_; - const ADataType* p_a_grid_imag_; - const BDataType* p_b_grid_real_; - const BDataType* p_b_grid_imag_; - CDataType* p_c_grid_real_; - CDataType* p_c_grid_imag_; - CDataType* p_aux_grid_; - CDataType* p_aux_2_grid_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M c_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; + const ADataType* p_a_grid_real; + const ADataType* p_a_grid_imag; + const BDataType* p_b_grid_real; + const BDataType* p_b_grid_imag; + CDataType* p_c_grid_real; + CDataType* p_c_grid_imag; + CDataType* p_aux_grid; + CDataType* p_aux_2_grid; + CGridDesc_M c_grid_desc_m; }; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; float ave_time = 0; @@ -578,224 +284,148 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - true>; - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_real_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_imag_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( stream_config, subtract_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_real_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), Subtract{}); - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_imag_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_real_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( stream_config, add_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_imag_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), Add{}); } else { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - false>; - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_real_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_imag_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( stream_config, subtract_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_real_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), Subtract{}); - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_imag_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_real_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( stream_config, add_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_imag_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), Add{}); } @@ -818,10 +448,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + return GridwiseGemm::CheckValidity(arg); } // polymorphic @@ -837,15 +464,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_real, CDataType* p_c_imag, CDataType* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) { return Argument{p_a_real, p_a_imag, @@ -854,15 +481,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle p_c_real, p_c_imag, p_workspace, - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op}; + StrideC}; } static auto MakeInvoker() { return Invoker{}; } @@ -875,15 +499,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle void* p_c_real, void* p_c_imag, void* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, index_t /* KBatch */ = 1) override { return std::make_unique(static_cast(p_a_real), @@ -893,15 +517,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static_cast(p_c_real), static_cast(p_c_imag), static_cast(p_workspace), - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + StrideC); } // polymorphic @@ -930,16 +551,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle return str.str(); } - std::size_t GetWorkspaceSize(index_t MRaw, - index_t NRaw, - [[maybe_unused]] index_t KRaw, + static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( + M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC); + + return c_grid_desc_m_n.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSize(index_t M, + index_t N, + [[maybe_unused]] index_t K, [[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideB, index_t StrideC) override { - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); - - return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize(); + return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 7a4c8bf2671a4c0ade6171c1f9a4acfd2b5af762..dd57c90898e7c908cb6952e74a2a0decca28a5b5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,7 +52,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -581,7 +582,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index b65afce8df5834d14f4ee9b3b19730e11db43bee..2ab09ba5c833a489ea22a3f3db06134037bdcedd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index ea3020663a88c1ce565beab17313e72928f50461..eb4db6f8cb098915654ecf32e9ea3c8bc1939ef0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; - OutElementwiseOperation a_element_op_; - WeiElementwiseOperation b_element_op_; - InElementwiseOperation c_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -716,8 +648,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -742,10 +673,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -759,12 +687,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -783,9 +706,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -799,12 +722,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index e7e2bf33540fd8dd05ba1062d32241d39ff371cd..bb8f5316134d8941fc24d17ed9a65d280353b9f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 6c4957b9b17f68c3a5d33470582d1d67907a8e85..1bd1e553cd36022ff86ba37aee57157e0feca1e1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 027f1a1954da565721a1bf072f58e8feddcf29ca..de6bf27fec37511ae40d98e6dda7b4352e1fa139 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 6278220c2f9dee047f632acef0e0d2f911156fa6..88615bba3134311b4120f4aa4483050df3e084a8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -329,9 +329,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -378,25 +375,13 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_in_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_out_grid}, a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - in_element_op_{in_element_op}, - wei_element_op_{wei_element_op}, - out_element_op_{out_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -420,17 +405,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } // private: @@ -440,14 +414,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - InElementwiseOperation in_element_op_; - WeiElementwiseOperation wei_element_op_; - OutElementwiseOperation out_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -479,17 +445,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } #endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -498,22 +461,18 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -521,30 +480,22 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -552,11 +503,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } return ave_time; @@ -616,10 +563,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -639,10 +584,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -656,12 +598,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -680,9 +617,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -696,12 +633,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index f69d8f18ae036b8e7cd060a80a0cca12bfc2a050..cd89f3232cb94e9ce9a68c06587758ac9574816f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_NAIVE_HPP #define DEVICE_CONV3D_FWD_NAIVE_HPP diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 31c761e09f80bc8fb5aeaec6bf268c00dbefc094..bc6582835e051c8335b7c4e12c5492b436a34392 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP @@ -55,7 +55,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 2a2edc29bac75fa7c7e5db352fcaa934754cdae4..3178f73f4b530e67b399cd0ae71064a4b66fa837 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1393,7 +1393,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) + if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 1fd4b76cec4cad5fd23b6a64166d210abdcb902a..77ad61d7ec5c446eed5c2171e36f2af2490d51a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -980,9 +980,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -1029,20 +1026,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -1092,17 +1079,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } template ::type = false> @@ -1150,18 +1126,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1218,19 +1182,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( - descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1242,11 +1193,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; OutElementwiseOperation a_element_op_; WeiElementwiseOperation b_element_op_; InElementwiseOperation c_element_op_; @@ -1276,123 +1222,84 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { #if DEBUG_LOG { - std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + std::cout << "arg.b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.c_grid_desc_m_n_container_{ " + std::cout << "arg.c_grid_desc_m_n{" << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -1446,8 +1353,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -1472,10 +1378,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -1489,12 +1392,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -1513,9 +1411,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -1529,12 +1427,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp index 83ed6198bd3c0921988d0d867a87a617bc2c76a1..02ef29e32ddc2bbd985fb4f46c6eb3825c00c11a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" namespace ck { namespace tensor_operation { @@ -171,10 +172,7 @@ struct DeviceElementwise2dImpl : public DeviceElementwise 0, ""); static_assert(NumDim_n > 0, ""); @@ -192,34 +190,10 @@ struct DeviceElementwise2dImpl : public DeviceElementwise(out_dev_buffers[I.value]); }, Number{}); - - in_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - inStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); - - out_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - outStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid2dDescTuple in_grid_2d_desc_tuple_; - OutGrid2dDescTuple out_grid_2d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -227,15 +201,38 @@ struct DeviceElementwise2dImpl : public DeviceElementwise{}); + + auto out_grid_2d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_MN(arg.lengths_, + arg.outStridesArray_[I.value], + gridSize, + arg.blockSize_, + num_threads_m, + num_threads_n); + }, + Number{}); + const auto kernel = kernel_elementwise_2d(out_dev_buffers[I.value]); }, Number{}); - - in_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, inStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); - - out_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, outStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid1dDescTuple in_grid_1d_desc_tuple_; - OutGrid1dDescTuple out_grid_1d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -187,13 +171,28 @@ struct DeviceElementwiseImpl ElementwiseOperation elementwise_op_; index_t blockSize_; - index_t gridSize_; }; struct Invoker : public BaseInvoker { float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + const auto kernel = kernel_elementwise_1d -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); -#else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; -#endif -} - -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace device { - -// input : A[M, K], or A[K, N] -// input : B[K, N], or A[N, K] -// input : D0[M, N], D1[M, N], ... -// output : E[M, N] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -template -struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute -{ - using DeviceOp = DeviceGemmBiasEPermute_Xdl; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - static constexpr index_t NumDTensor = 1; - - static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) - { - index_t M0 = d_e_grid_desc.M0_; - index_t M1 = d_e_grid_desc.M1_; - index_t M2 = d_e_grid_desc.M2_; - index_t N0 = d_e_grid_desc.N0_; - index_t N1 = d_e_grid_desc.N1_; - - index_t stride_M0 = d_e_grid_desc.stride_M0_; - index_t stride_M1 = d_e_grid_desc.stride_M1_; - index_t stride_M2 = d_e_grid_desc.stride_M2_; - index_t stride_N0 = d_e_grid_desc.stride_N0_; - index_t stride_N1 = d_e_grid_desc.stride_N1_; - - const auto e_grid_desc_mraw_nraw = [&]() { - const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( - make_tuple(M0, M1, M2, N0, N1), - make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); - - return transform_tensor_descriptor( - e_grid_desc_m0_m1_m2_n0_n1, - make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), - make_merge_transform(make_tuple(N0, N1))), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - }(); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); - - using DsGridDesc_M_N = Tuple; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; - - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; - - // Argument - struct Argument : public BaseArgument - { - Argument(const void* p_a_grid, - const void* p_b_grid, - const void* p_d_grid, - void* p_e_grid, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : p_a_grid_{static_cast(p_a_grid)}, - p_b_grid_{static_cast(p_b_grid)}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, - a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - - if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - // populate pointer, desc for Ds - // D pointer - p_ds_grid_(I0) = static_cast(p_d_grid); - - // D desc - ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_[I0]); - } - } - - // private: - // pointers - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_gemm_bias_e_permute< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2ETileMap, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_); - }; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const void* p_a, - const void* p_b, - const void* p_d, - void* p_e, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{p_a, - p_b, - p_d, - p_e, - MRaw, - NRaw, - KRaw, - StrideA, - StrideB, - d_grid_desc, - e_grid_desc, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_d, - void* p_e, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) override - { - return std::make_unique(p_a, - p_b, - p_d, - p_e, - MRaw, - NRaw, - KRaw, - StrideA, - StrideB, - d_grid_desc, - e_grid_desc, - a_element_op, - b_element_op, - cde_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmBiasEPermute_Xdl" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << K1 << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << ABlockTransferDstScalarPerVector_K1 << ", " - << BBlockTransferSrcScalarPerVector << ", " - << BBlockTransferDstScalarPerVector_K1 << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << CBlockTransferScalarPerVector_NWaveNPerXdl - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index af1989fc4a2cc41783806bbe3c2f67718e45e524..13e9f96910fadeb0fdda4ef723df3f260bc7ee7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -485,7 +485,9 @@ struct DeviceGemmDl : public DeviceGemm( @@ -806,7 +807,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // workspace for welford intermediate mean workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; - // workspace for welford intermediate mean + // workspace for welford intermediate variance workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; // workspace for welford intermediate count @@ -854,7 +855,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index f1185357a49eb6feb4868ac8fa9f4e387fb60eae..3a0be837b7f3aa6683b7ade07a3d32f71f1cc0b5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,7 +60,8 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -554,7 +555,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index b59bb2b3038e06ab698cc930aeacd0cd3a25c5cd..44b3518e2c950e15ec3d5cfd240c443029987622 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 36e81051221cb6832d8048975fd6012dc9742d31..d4ecea7bb7a815b9f5f98e71b4350c8f1ca7c43b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -51,7 +51,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -490,7 +491,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index 48fb646375e0de03ea4973492b791e7e93205847..ebbb1d38c000ba44ce53f9a8110114c41a299b14 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -75,132 +75,20 @@ struct DeviceGemmXdl : public DeviceGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, MPerBlock, NPerBlock, K0PerBlock, @@ -232,173 +120,41 @@ struct DeviceGemmXdl : public DeviceGemm; - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} - { - a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(stream_config.log_level_ > 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -418,7 +174,7 @@ struct DeviceGemmXdl : public DeviceGemm || is_same_v || is_same_v || is_same_v)) @@ -441,15 +197,12 @@ struct DeviceGemmXdl : public DeviceGemm(static_cast(p_a), static_cast(p_b), @@ -511,12 +251,7 @@ struct DeviceGemmXdl : public DeviceGemm{}; static constexpr auto I2 = Number<2>{}; - static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto MPad = M - MRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_right_pad_transform(MRaw, MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - // not pad M or K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto NPad = N - NRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(NRaw, NPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // not pad N or K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - } - - static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideC)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ALayout, + BLayout, + CLayout, ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -359,10 +94,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{KRaw} - { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(stream_config.log_level_ > 0) { - std::cout << "arg.a_grid_desc_ak0_m_ak1_{" - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_bk0_n_bk1_{" - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + arg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; float ave_time = 0; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - true>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); } else { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - false>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); } return ave_time; @@ -574,24 +188,22 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override { return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + StrideC); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index 2a3e74a4cb0f25e04d886cf3e42066c5c96a348e..a2403f109c50da0c21bb6272798b5b91bb7ed97b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -648,7 +648,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index 36b01f677f080e4f2f515738af6171d6ed4dbbcd..ef5e267819c7fd8ca3868f689fd7bf16dd50fbb8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 0d2aeaeb7e6a35bdfe05b8e2ed2005d70b75f55c..6c07a5c56ce94a06f5602c722af100fe0a8460d9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -73,157 +73,24 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto K1Number = Number{}; + // TODO: should be exposed as Tparams. + static constexpr index_t NumGemmKPrefetchStage = 1; + static constexpr LoopScheduler LoopSched = make_default_loop_scheduler(); + static constexpr PipelineVersion PipelineVer = PipelineVersion::v1; - static auto - MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_m_kpad = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto - MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_kpad_n = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - static auto GetKPad(index_t K, index_t KBatch) - { - const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; - const index_t KPad = KBatch * K0 * K1; - return KPad; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, @@ -251,238 +118,72 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer>; - // GridwiseGemm - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::AtomicAdd, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXDL, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; - - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - - using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t k_batch) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - k_batch_{k_batch} - { - int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); - - a_grid_desc_kbatch_k0_m_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( - M, K, StrideA, k_batch_, KPad); - b_grid_desc_kbatch_k0_n_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( - K, N, StrideB, k_batch_, KPad); - c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t k_batch_; - }; + using Argument = typename GridwiseGemm::Argument; + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdlSplitKCShuffle::Argument; - void Print(const Argument& arg) - { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } + void Print(const Argument& karg) { karg.Print(); } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { - Print(arg); + Print(karg); } - const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const auto kbatch = karg.k_batch; - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0 = karg.K0; const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); float ave_time = 0; const auto Run = [&](const auto& kernel) { - hipGetErrorString(hipMemset( - arg.p_c_grid_, - 0, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + if(kbatch > 1) + hipGetErrorString( + hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType))); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map); }; if(has_main_k0_block_loop) { if(kbatch == 1) { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -491,37 +192,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -544,12 +229,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(static_cast(p_a), @@ -615,11 +296,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK"; - // clang-format on - - return str.str(); - } + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 03d9e26a460962222c90f393df7115fcd4ecdf76..4b1cc65ba28ff0aff771c57b3080caede43e8d84 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,7 +37,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -651,11 +652,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle } } - hipGetErrorString(hipMemcpy(arg.p_workspace_, - arg.contraction_multi_d_kernel_args_.data(), - arg.contraction_multi_d_kernel_args_.size() * - sizeof(ContractionMultiDKernelArg), - hipMemcpyHostToDevice)); + hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, + arg.contraction_multi_d_kernel_args_.data(), + arg.contraction_multi_d_kernel_args_.size() * + sizeof(ContractionMultiDKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); float ave_time = 0; @@ -703,7 +705,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 4cef404447f6685a0992d1201f25cecb40d986c4..95af932519c4977a7ffe773d01b33a9f30165665 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -130,7 +130,8 @@ __global__ void const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -458,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_k_wos_lengths[0]}, - num_gemm_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -507,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; - // number of GEMM - num_gemm_ = YTilde * XTilde; - for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) { for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) @@ -625,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 void Print() const { - for(index_t i = 0; i < num_gemm_; i++) + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) { std::cout << "a_grid_desc_ak0_m_ak1_container_" << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; @@ -653,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // tensor descriptor for problem definition index_t num_group_; - index_t num_gemm_; std::vector a_grid_desc_m_k_container_; std::vector b_grid_desc_n_k_container_; std::vector ds_grid_desc_m_n_container_; @@ -707,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 float ave_time = 0; - for(index_t i = 0; i < arg.num_gemm_; i++) + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], arg.b_grid_desc_n_k_container_[i], @@ -806,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector load for A matrix from global memory to LDS - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v) { if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) { @@ -861,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector store for E - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v) { // vector store C matrix into global memory if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp index 9529744f31daffb294a27c5a96f6708c22fc3a8a..0473eaf7625fe64b99736e45f96650ee076780f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp @@ -1027,7 +1027,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) + if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp index 34db9f2a52aafa43e80359bfa1cb4238800e9102..4f4e0d576ada5aeebcdfc42d7431164254df68f1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -78,7 +78,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp index 03185d5b1d29186f2b2814257eacfe90da6c5e02..face627e1fb85d0621823803eabc5409b0e73d61 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index b60f94d141007287c0b5b4452ec11308044d9cdf..af8e0e02794def75284258a09609baefc688ebe9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -155,7 +155,8 @@ __global__ void const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -810,7 +811,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index a1ee1902b54323482074292cff782dca6c61db78..526108b87442b4077bc9723820edcd09b7fd76f3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -840,17 +840,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle << KPerBlock << ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << K1 << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " << ABlockTransferSrcScalarPerVector << ", " - << ABlockTransferDstScalarPerVector_K1 << ", " - << BBlockTransferSrcScalarPerVector << ", " - << BBlockTransferDstScalarPerVector_K1 << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << CBlockTransferScalarPerVector_NWaveNPerXdl + << BBlockTransferSrcScalarPerVector << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 8de81285d9a73b3aa5e41fd216c7dc733b86d83d..a9294f0344efa5f3a3fc522f3f3e6544def80027 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -135,7 +135,8 @@ __global__ void const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -684,7 +685,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index d424f2992076cc32704631c3badbefa3c88dec2c..0190b3cee6b44af7656c34be866c9ab0855f7b0d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -596,10 +597,12 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].karg_, + static_cast(p_shared), + gemm_desc_ptr[group_id].block_2_ctile_map_); +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template > && + is_same_v>, + bool> = false> +struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + EDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVersion::v2>; + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + KernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(KernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_c, + m_padded, + n_padded, + k_padded, + k0, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] kbatch The new splitK parameter value. + */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KPadded = k_padded; + karg.K0 = k0; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t K0 = arg.gemm_kernel_args_[0].karg_.K0; + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; + bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + K0 = karg.K0; + bool not_all_have_main_k0_block_loop_same = + all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " + << "group [" << i << "], kbatch: " << kbatch + << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch + << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + hip_check_error( + hipMemcpyWithStream(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error( + hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType))); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size()); + }; + + if(all_have_main_k0_block_loop) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + else + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; +#endif // DEBUG_LOG + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& a = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + if(not group_arg_valid) + { +#if DEBUG_LOG + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + a.Print(); +#endif // DEBUG_LOG + } + supported = supported && group_arg_valid; + } + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) + { + return Argument{p_As, p_Bs, p_Es, gemm_descs}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) override + { + return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_XdlSplitK" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..175994d49f3a3121e2e32a4ca2ef07e60c0e52e5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd +{ + using DInDataType_AutomicAddPreCast = + conditional_t || is_same_v, + DInDataType, + float>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; + + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step) + { + const auto m = desc_m.GetLength(I0); + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t loop_step) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, loop_step); + } + + using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); + + using GridwisePutElementSet = GridwisePutElement_1D; + + using GridwisePutElementAtomicAdd = GridwisePutElement_1D; + + using GridwiseCasting = GridwiseElementwise_1D, + Tuple, + Tuple, + Tuple, + UnaryConvert, + InOutVectorSize, + Sequence, + Sequence>; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + const IndexDataType* p_indices, + DInDataType* p_din, + index_t dout_length, + index_t din_length, + const std::vector& window_lengths, + const std::vector& window_strides) + : p_dout_{p_dout}, + p_indices_{p_indices}, + p_din_{p_din}, + dout_length_raw_{dout_length}, + din_length_raw_{din_length}, + blockSize_{256}, + windowOverlap_{false} + { + for(size_t i = 0; i < window_lengths.size(); ++i) + { + windowOverlap_ |= window_lengths.at(i) > window_strides.at(i); + } + } + + const DOutDataType* p_dout_; + const IndexDataType* p_indices_; + DInDataType* p_din_; + index_t dout_length_raw_; + index_t din_length_raw_; + index_t blockSize_; + bool windowOverlap_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; + InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); + InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); + + if constexpr(is_same_v || is_same_v) + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + if(arg.windowOverlap_) + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + else + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + else + { + if(arg.windowOverlap_) + { + if(arg.p_workspace_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + hip_check_error( + hipMemsetAsync(arg.p_workspace_, + 0, + arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + const auto cast_kernel = + kernel_elementwise_1d, + Tuple, + Tuple, + Tuple, + UnaryConvert>; + + float elapsed_time = launch_and_time_kernel( + stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + static_cast(arg.p_workspace_), + PassThrough{}); + + elapsed_time += launch_and_time_kernel( + stream_config, + cast_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + ck::make_tuple(din_grid_desc), + ck::make_tuple(din_grid_desc), + static_cast(arg.p_workspace_), + arg.p_din_, + UnaryConvert{}); + + return elapsed_time; + } + else + { + const auto put_kernel = kernel_put_element_1d; + + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + bool needCast = pArg_->windowOverlap_ && + !(is_same_v || is_same_v); + + if(!needCast) + return 0; + else + return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + if(pArg->din_length_raw_ % InOutVectorSize != 0 || + pArg->dout_length_raw_ % InOutVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides) override + { + // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are + // physical size of the packed tensor + return std::make_unique(static_cast(p_dout), + static_cast(p_indices), + static_cast(p_din), + dout_length, + din_length, + window_lengths, + window_strides); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp index b49e1096829695f7f7d5b0f651a20f7aebc03f6b..aec5a65ccf41eafcbe6ebdf790cc75b0adbf2966 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp index 17a96e9f6f6c8039485afcab82d5908d92a59a24..6d1d5c8e239de8f57df1baa82d4981e8f30e54b6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index bb62332d1ad54eca7fbb8d02f4b511b989530426..ea0d805043e07fe3d818fbb7bf7ef42e952496a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,8 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -20,6 +19,10 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) +// M: Invarient length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W template & inLengths, const std::vector& inStrides, - int blkGroupSize, int numBlockTileIteration) { constexpr index_t NumInvariantDim = Rank - NumReduceDim; @@ -117,10 +119,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; const auto inPad_M = math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; auto in_grid_desc_m_k_padded = transform_tensor_descriptor( in_grid_desc_m_k, @@ -132,7 +133,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization(gammaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); - long_index_t invariant_total_length; - long_index_t reduce_total_length; + long_index_t invariant_length; + long_index_t reduce_length; - std::tie(invariant_total_length, reduce_total_length) = + std::tie(invariant_length, reduce_length) = get_2d_lengths(Lengths_); - blkGroupSize_ = 1; - numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); - gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / - M_BlockTileSize * blkGroupSize_; + gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); - x_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); gamma_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); + MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_); beta_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); - y_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); + y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); isSweeponce_ = x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; @@ -202,7 +199,6 @@ struct DeviceNormalizationImpl : public DeviceNormalizationinvariant_lowest_length % XSrcVectorSize != 0) return false; + + if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) + return false; }; } else @@ -295,12 +294,12 @@ struct DeviceNormalizationImpl : public DeviceNormalizationLengths_[Rank - 1] % XSrcVectorSize != 0) return false; - }; - if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) - { - return false; - } + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + }; // if fastest dim is not reduced if constexpr(GammaSrcVectorDim == 0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b2b3c41bfdf3f4793263dd65fc5f6dc941eb13b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp @@ -0,0 +1,658 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count) +{ + GridwiseWelford::Run(x_grid_desc_m_k, + mean_var_grid_desc_m_kblock, + num_k_block_tile_iteration, + p_x_global, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +__global__ void +kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const MeanVarDataType* const p_mean_global, + const MeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, + count_grid_desc_m_kblock, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + num_k_mean_var_count_iteration, + num_k_block_tile_iteration, + k_grid_size, + epsilon, + p_mean_global, + p_variance_global, + p_welford_count_global, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + y_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invarient length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationSplitKImpl : public DeviceNormalization +{ + using MeanVarDataType = ComputeDataType; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int kBlockSize, + int numBlockTileIteration) + { + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t numSrcDim = Rank; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + template + static auto MakeCountDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + using Kernel1MeanVarGridDesc_M_KBlock = + decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2MeanVarGridDesc_M_KBlock = + decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2CountGridDesc_M_KBlock = + decltype(MakeCountDescriptor_M_K, 1, 1>(1, 1)); + + using GridwiseWelford = GridwiseNormalizationSplitK1st; + + using GridwiseWelfordNormalization = + GridwiseNormalizationSplitK2nd; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = 1; + while(true) + { + int testKGridSize = + math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + + // we want the kGridSize_ be not more than 128 + if(testKGridSize <= 128) + break; + + ++numBlockTileIteration_; + }; + + kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; + + // We do not use vector load for mean, var and count + static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize; + + numMeanVarCountIteration_ = + math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); + + // We don't need to pad in K dimension for Welford1. Set KPerTile 1. + kernel1_mean_var_grid_desc_m_kblock_ = + MakeMeanVarDescriptor_M_K, M_BlockTileSize, 1>(MRaw_, + kGridSize_); + + kernel2_mean_var_grid_desc_m_kblock_ = + MakeMeanVarDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + kernel2_count_grid_desc_m_kblock_ = + MakeCountDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + + YElementwiseOperation y_elementwise_op_; + + int kGridSize_; + int numMeanVarCountIteration_; + int numBlockTileIteration_; + size_t gridSize_; + + SrcGridDesc_M_K x_grid_desc_m_k_; + SrcGridDesc_M_K gamma_grid_desc_m_k_; + SrcGridDesc_M_K beta_grid_desc_m_k_; + SrcGridDesc_M_K y_grid_desc_m_k_; + + Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; + Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; + Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; + + index_t MRaw_; // invarient length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr || + arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + auto kernel1 = kernel_normalizationSplitK1st; + + auto kernel2 = kernel_normalizationSplitK2nd; + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel1, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.kernel1_mean_var_grid_desc_m_kblock_, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + avg_time += launch_and_time_kernel(stream_config, + kernel2, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.kernel2_mean_var_grid_desc_m_kblock_, + arg.kernel2_count_grid_desc_m_kblock_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.numMeanVarCountIteration_, + arg.numBlockTileIteration_, + arg.kGridSize_, + arg.epsilon_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.y_elementwise_op_); + + return avg_time; + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // workspace for welford intermediate mean + workspace_size += welford_size * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += welford_size * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = welford_size * sizeof(MeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = welford_size * sizeof(MeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + if constexpr(XYVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return false; + } + + if(p_arg_->kGridSize_ <= 1) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvVar, + YElementwiseOperation y_elementwise_op) override + { + // TODO + // Optional cache of the intermediate results (mean and InvVariance) during the + // forward pass could speedup in the backward + ignore = p_saveMean; + ignore = p_saveInvVar; + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationSplitKImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp index 7b96373c0ff1fc86a79c4248db35dd5812f2c5b1..17dab08332166e23a93d5c8f25ecfec7556519ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp index bfde40cda2a5369015b08de568f72c786295848f..3f27c629dd6ac88a41b0072a1f51afc3b3fb3b25 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,7 +9,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" -#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -20,16 +20,18 @@ namespace device { template -struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd +struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C + : public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -38,7 +40,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; - using IndexDataType = int32_t; + static constexpr index_t InOutRank = 4; + static constexpr index_t WindowRank = 2; using ReduceOperation = typename reduce_binary_operator::opType; @@ -59,12 +62,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) + std::vector input_spatial_lengths, + std::vector window_spatial_lengths, + std::vector output_spatial_lengths, + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads) { const index_t Hi = input_spatial_lengths[0]; const index_t Wi = input_spatial_lengths[1]; @@ -141,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); } - using ABGridDescs = decltype( - MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - + using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); using AGridDesc_M_K = remove_cvref_t; using BGridDesc_M = remove_cvref_t; @@ -152,15 +153,15 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd { Argument(const InDataType* p_in_dev, OutDataType* p_out_dev, - int* p_out_indices_dev, + IndexDataType* p_out_indices_dev, ck::index_t N, ck::index_t C, - std::array& input_spatial_lengths, - std::array& window_spatial_lengths, - std::array& output_spatial_lengths, - std::array& window_strides, - std::array& input_left_pads, - std::array& input_right_pads) + std::vector& input_spatial_lengths, + std::vector& window_spatial_lengths, + std::vector& output_spatial_lengths, + std::vector& window_strides, + std::vector& input_left_pads, + std::vector& input_right_pads) : p_in_dev_{p_in_dev}, p_out_dev_{p_out_dev}, p_out_indices_dev_{p_out_indices_dev}, @@ -190,7 +191,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd const InDataType* p_in_dev_; OutDataType* p_out_dev_; - int* p_out_indices_dev_; + IndexDataType* p_out_indices_dev_; AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_M b_grid_desc_m_; InElementwiseOperation in_element_op_; @@ -208,7 +209,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; - const auto kernel = kernel_reduce_threadwise; + const auto kernel = + kernel_reduce_threadwise; ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); @@ -280,22 +283,42 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, void* p_out_indices_dev, - ck::index_t N, - ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) override + std::vector input_lengths, + std::vector window_lengths, + std::vector output_lengths, + std::vector, // Suppose tensor layout = NHWC + std::vector, // Suppose tensor layout = NHWC + std::vector, // Suppose tensor layout = NHWC + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) override { + if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || + input_lengths.size() != InOutRank || window_strides.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3}) + throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); + + index_t N = input_lengths[0]; + index_t C = input_lengths[1]; + index_t Hi = input_lengths[2]; + index_t Wi = input_lengths[3]; + index_t Ho = output_lengths[2]; + index_t Wo = output_lengths[3]; + + std::vector input_spatial_lengths = {Hi, Wi}; + std::vector output_spatial_lengths = {Ho, Wo}; + return std::make_unique(static_cast(p_in_dev), static_cast(p_out_dev), - static_cast(p_out_indices_dev), + static_cast(p_out_indices_dev), N, C, input_spatial_lengths, - window_spatial_lengths, + window_lengths, output_spatial_lengths, window_strides, input_left_pads, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0ab6c247584881bd72723b6cda2f412aea9a2e1c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C + : public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr index_t InOutRank = 5; + static constexpr index_t WindowRank = 3; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + // for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not + // reduced. + static constexpr index_t InSrcOutDstVectorDim = 0; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector window_spatial_lengths, + std::vector output_spatial_lengths, + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads) + { + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = window_spatial_lengths[0]; + const index_t Y = window_spatial_lengths[1]; + const index_t X = window_spatial_lengths[2]; + + const index_t ConvStrideD = window_strides[0]; + const index_t ConvStrideH = window_strides[1]; + const index_t ConvStrideW = window_strides[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t MRaw = N * Do * Ho * Wo * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = Z * Y * X; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // A[ReduceM, ReduceK] + const auto in_grid_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)), + make_merge_transform(make_tuple(Z, Y, X))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const auto out_grid_desc_reducemraw = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C)); + + const auto out_grid_desc_reducem = + transform_tensor_descriptor(out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + IndexDataType* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::vector& input_spatial_lengths, + std::vector& window_spatial_lengths, + std::vector& output_spatial_lengths, + std::vector& window_strides, + std::vector& input_left_pads, + std::vector& input_right_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + invariant_lowest_length_ = C; + + int32_t reduceLength = + window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + IndexDataType* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + ck::index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = + kernel_reduce_threadwise; + + ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (M / M_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) + { + return false; + } + + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_lengths, + std::vector window_lengths, + std::vector output_lengths, + std::vector, // Suppose tensor layout = NDHWC + std::vector, // Suppose tensor layout = NDHWC + std::vector, // Suppose tensor layout = NDHWC + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) override + { + if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || + input_lengths.size() != InOutRank || window_strides.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3, 4}) + throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); + + index_t N = input_lengths[0]; + index_t C = input_lengths[1]; + index_t Di = input_lengths[2]; + index_t Hi = input_lengths[3]; + index_t Wi = input_lengths[4]; + index_t Do = output_lengths[2]; + index_t Ho = output_lengths[3]; + index_t Wo = output_lengths[4]; + + std::vector input_spatial_lengths = {Di, Hi, Wi}; + std::vector output_spatial_lengths = {Do, Ho, Wo}; + + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + N, + C, + input_spatial_lengths, + window_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7334da0e35af5ce92665decc4897ebe48393b9c3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_put_element.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElementImpl + : public DevicePutElement +{ + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * InVectorSize; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1)); + + using GridwisePutElement = GridwisePutElement_1D; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_input, + const IndexDataType* p_indices, + OutDataType* p_output, + index_t input_length, + ElementwiseOperation elementwise_op) + : p_input_{p_input}, + p_indices_{p_indices}, + p_output_{p_output}, + input_length_raw_{input_length}, + elementwise_op_{elementwise_op}, + blockSize_{256} + { + } + + const InDataType* p_input_; + const IndexDataType* p_indices_; + OutDataType* p_output_; + index_t input_length_raw_; + ElementwiseOperation elementwise_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + InGrid1dDesc in_grid_desc = + MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_); + + const auto kernel = kernel_put_element_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_desc, + arg.p_input_, + arg.p_indices_, + arg.p_output_, + arg.elementwise_op_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->input_length_raw_ % InVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(static_cast(p_input), + static_cast(p_indices), + static_cast(p_output), + input_length, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp index 5dc051be3cb7f23fef076921e826a7e55e017d76..2481c5c76971db6bb2ddc390510089af5f656ae6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp index c7868537fe8220c640ba24d346eea5dc219490fa..bf3deeb57acbdff95995f5da096ab574140bcb42 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp index a1d976f1a17c698c538b0f7ba1b2b200162e8eb2..6c5895b010659cafb98e4a78bc6587d0222e02bd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,6 +28,7 @@ template + Rank, + NumReduceDim> { - static constexpr index_t kRank = Rank; - static constexpr index_t kNumReduceDim = NumReduceDim; - static constexpr index_t kNumInvariantDim = Rank - NumReduceDim; - - virtual index_t GetRank() const override { return kRank; } - - virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumSrcDim = Rank; @@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) + if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp index 2f29224a754057430cd68bb551807226d40a6cb5..7a62ec04650a94d1b9f458e425ddf3b14881091d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index ea0f5897a757c408047b20e0cd4285e13f67c4f5..d6d6f74abdb5a63ef8a841d033c9f47ec59bd07e 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 70e61bc7728f9638e5bef874ab0d5caad2db5028..c66d2fc516babe3aa5a16de053d9ada89eabf8ef 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index d35318357a9bcc6fb087c92e27dcf435ff60796b..5351d4ef2439307ad5623f4172217937812ab29a 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index b44427411f9a28511fd16f6b47ee37fa84a0e2d3..b2d141fd61aa4550ef8b572e548688120634b8b9 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp index 0ec0df2c9bbbb2aac987d681a36d5b00a8349c97..713fc93ebb4eec9ce0301affd667656bf0b6c999 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/include/ck/tensor_operation/gpu/device/welford_helper.hpp index 6c909b767d4602914660d045b9aaf5a2307ea57b..d7772d8764ff3cc66572775f750997fae944800e 100644 --- a/include/ck/tensor_operation/gpu/device/welford_helper.hpp +++ b/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 136017c6d17027a67fee3b493a20a8d3fabe2bd4..1dd96809d5f79ed4921fa6807625727279107df7 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index ceb2b665b91529f0a97cfb1209b2de40744ee466..3fdb391a0215e83328905b77148b675defc27af7 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp index 7ea09a2220474ba89de38995ed4d6e4c6b550fa9..fefa6c793f7e91e3744fe028469fd4359e4e6bbc 100644 --- a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -7,10 +7,30 @@ namespace ck { namespace tensor_operation { namespace element_wise { +// Y = Sy * Qy +// W = Sw * Qw +// X = Sx * Qx +// B = Sb * Qb = Sw * Sx * Qb +// Where X, W, Y are float32, Qx, Qw, Qy are int8 +// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range +// Qb is int32, scale of B is Sw * Sx for convenient + +// Y = W @ X, where @ is convolution or matrix multiplication +// Sy * Qy = Sw * Qw @ Sx * Qx +// Qy = [(Sw*Sx)/Sy] * Qw @ Qx + // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Activation_Mul_Clamp { + // Convolution + Activation (piecewise linear function) + // If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx) + + // requantScale_ = Sw * Sx / Sz Activation_Mul_Clamp(float requantScale, Activation activationOp) : requantScale_(requantScale), activationOp_(activationOp) { @@ -45,8 +65,39 @@ struct Activation_Mul_Clamp Activation activationOp_; }; +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) +template +struct Mul_Activation_Mul_Clamp +{ + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float y_fp32 = ck::type_convert(x); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + // Conv Perchannel quantization + Activation function which is piecewise linear function, such as // relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Activation_Mul2_Clamp { @@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp }; // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Add_Activation_Mul_Clamp { + // Convolution + bias + // Let Bias = B = Sw * Sx * Qb + // Where Qb is int32 + // Y = W @ X + B + // Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb + // Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb) + + // For activation, Z = Activaiton(Y) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) : requantScale_(requantScale), activationOp_(activationOp) { @@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp }; // For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) template struct Add_Mul_Activation_Mul_Clamp { - Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) - : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X + B) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) { } @@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp operator()(int8_t& y, const int32_t& x, const int32_t& bias) const { float y_fp32 = ck::type_convert(x + bias); - y_fp32 = requantScale1_ * y_fp32; + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc_ * y_fp32; activationOp_(y_fp32, y_fp32); - y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is non piecewise linear function, +// such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy *Qy) != Sy * Activation(Qy) +template +struct Add_Mul2_Activation_Mul_Clamp +{ + Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp) + : scale_z_inv_(scale_z_inv), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); y = ck::type_convert(y_fp32); } - float requantScale1_; - float requantScale2_; + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; Activation activationOp_; }; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 6b4df3b60e363c5e3c7a14e4e15ecf6e48eff3ec..4fb061fadba3e136ded891149594550f91d3e783 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -56,6 +57,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { @@ -75,6 +82,36 @@ struct PassThrough y = x; } #endif + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const half_t& x) const + { + y = type_convert(x); + } }; struct UnaryConvert @@ -86,6 +123,40 @@ struct UnaryConvert } }; +struct ConvertBF16RTN +{ + // convert to bf16 using round to nearest (rtn) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = bf16_convert_rtn(x); + } +}; + +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} @@ -316,8 +387,36 @@ struct Sigmoid y = 1 / (ck::type_convert(1) + exp(-x)); }; +}; - int32_t divider_ = 1; +struct TanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tanh(x); + }; +}; + +struct Swish +{ + Swish(float beta = 1.0f) : beta_(beta) {} + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); + }; + + float beta_ = 1.0f; }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp index a72a4ee068f2bbcb1a915df0368c9b05caa298f9..e73e7e681721c4558c7b58b0bce962b6e156ec58 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp index 08cb0dd191303813a9283b19e38e2952fa0ccf27..fc263138f9dd7ed3057e5259d015b70450abbf21 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp index 548d7fd40ac70fe7fde5e5973a177600b0941c5c..1f8990e6df835115ce0d471134553a14494c1b82 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 42b7e172b239e69ae0ca840dd80d8bb203a5480f..6fe78edb34df290de485f0d015c3ccd86e82b65b 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index fe4dce2b9dccaa9533b6182ba3259755d7ac8d8b..5c8b9f419ae11d62a23f83320a8659122fb5d52d 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -109,30 +109,57 @@ struct BlockToCTileMap_M00_N0_M01 // Rows of column-vectors // This C-tile map dynamically adjusts M01 when C-tile index is out of range -template -struct BlockToCTileMap_M00_N0_M01Adapt +template +struct BlockToCTileMap_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_M00_N0_M01Adapt { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { + } + + template __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) - : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + : BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) { } - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { - const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); - const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } - const index_t grid_size = M0 * N0; + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } - return grid_size; + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; } template @@ -140,8 +167,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt { auto block_1d_id = idx_top[I0]; - const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); - const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); block_1d_id = block_1d_id % (M0 * N0); // swallow batch index @@ -209,11 +236,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } - private: + index_t M_; + index_t N_; index_t M01_; - CGridDesc_M_N c_grid_desc_m_n_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +{ + using BlockToCTileMap_M00_N0_M01Adapt:: + BlockToCTileMap_M00_N0_M01Adapt; }; // 2D slices of column-vectors in 3D space @@ -587,4 +621,52 @@ struct OffsettedBlockToCTileMap index_t block_start_; }; +/** + * @brief Simple tile mapping which creates 3D grid of block of threads. + * + * @paragraph Description + * This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread + * blocks. The first 2D are regular 2D tiles created by division of output GEMM + * dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension, + * which denotes the number of blocks we use to divide work on GEMM K dimension onto. + * + * @tparam MPerBlock Output block tile size in M dimension. + * @tparam NPerBlock Output block tile size in N dimension. + */ +template +struct BlockToCTileMap_3DGrid_KSplit +{ + + __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default; + + __host__ __device__ constexpr auto + CalculateGridSize(index_t M, index_t N, index_t k_split) const + { + // Create 3D grid + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return std::make_tuple(N0, M0, k_split); + } + + template + __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const + { + return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index aa34cfbf84ab9425a236646400c1a965c8004f74..523e7f7c539da7ac6e5729c6e87b4a21f1bd3f6b 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index fbe89e7e5e5936c2db1470e63f922cc7f6c5ff2a..69468c25befe036123cf1796798e6a59ea47f74b 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index bdebe3816f22d6a4f10e913210abc8c8347e20d0..bd1e0585fc96ad29ee9e5336993f2ba6570fb31d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index 1313ec9435e782ad9cc9b92580464f6d5257b93e..fc4f27e33b15b6bfcc7ce942764f9aa5e26cfe1f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index 6836a66047531d9eef09dbffc58188f7696a783a..203be3c42d9d23c935d10f34ceb44245f7875ad0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index 6c5bd29f9b5ee70b995342f04d59b8ed042b1194..910c926c7e470e0a7daa1cd25f57716f5ce4af6b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ namespace ck { template (in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_value_global, - p_in_index_global, - beta, - p_out_value_global, - p_out_index_global); + GridwiseReduction::template RunWithIndex( + in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); }; }; @@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); }; - template + template __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, const OutGridDesc_M& out_grid_desc_m, const InElementwiseOperation& in_elementwise_op, @@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise indexStart += KThreadSliceSize; reducedLength += KThreadSliceSize; } while(reducedLength < toReduceLength); + + if constexpr(TransformIndexKtoGlobal) + { + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + const auto coord = make_tensor_coordinate( + in_grid_desc_m_k, + make_multi_index(thread_global_1d_id * MThreadSliceSize + I, + accu_index_buf(I))); + + accu_index_buf(I) = coord.GetOffset(); + }); + } }; // for indiced operation, acc_elementwise_op shoud do nothing diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index fccb127d0f342911ed6bb1f8396c9d9cd0eec921..a8a1f803fcb260e3eba6ccf301f68b7638e2108b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index b9f4a3080a0e17ca0ba169bc8b83c006acb2305d..59d6bad5d97dacfda3947868aa8c14d0eaf220cc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 6a6f19d71ef523697d69719de75b70bce3953e7e..135b9da6a00643c8e74bca6fdf4c36dbcc65f288 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,7 +80,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle { static_assert(LoopSched == LoopScheduler::Default, @@ -621,13 +622,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId I1, // NBlockID - I1, // MRepeat - I1, // NRepeat - I1, // MWaveId - I1, // NWaveId - I1, // MPerXdl - I1, // NGroupNum - I1, // NInputNum + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum n4)); // registerNum auto d0s_thread_buf = generate_tuple( @@ -644,9 +645,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const auto wave_id = GetGemm0WaveIdx(); const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 - constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, n2, n4)); - auto d0s_threadwise_copy = generate_tuple( [&](auto i) { using D0DataType = remove_cvref_t>; @@ -655,10 +653,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle D0DataType, decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, 9, - n4, + D0sTransferSrcScalarPerVector, 1, false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], make_multi_index(block_work_idx[I0], // MBlockId @@ -884,62 +891,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // multiple d if constexpr(NumD0Tensor) { - static_for<0, MXdlPerWave, 1>{}([&](auto mr) { - static_for<0, NXdlPerWave, 1>{}([&](auto nr) { - static_for<0, n2, 1>{}([&](auto groupid) { - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).Run( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - d0s_grid_buf[i], - d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - d0s_thread_buf(i)); - }); - - static_for<0, n4, 1>{}([&](auto i) { - constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( - make_tuple(mr, nr, groupid, i)); - - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - return d0s_thread_buf[iSrc][i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { - return acc_thread_buf(Number{}); - }, - Number<2>{}); - - unpack2(c0de_element_op, dst_data_refs, src_data_refs); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0)); - }); + static_assert(NXdlPerWave == n0); + static_assert(MXdlPerWave == m0); + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc_thread_buf(i); }, + Number<2>{}); + + unpack2(c0de_element_op, dst_data_refs, src_data_refs); }); static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); }); } else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index d6d2051113ff4fdebb74689293b00c532aae9c68..9eb2bf8aa81ea2689a4a8c71c2b78b070e701a4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index ede6a96dc9f53045b5fc029f4c4f4f6aa4fa4817..ed1ffdd85765e84feb2f6824a858ab842487c18c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp index 33c45a0f037b4c81fbaf810cc10e55f71a2ab254..b6c83af13a18639cf7f8312dd416dbf4d9cb5297 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp deleted file mode 100644 index 2369f51795d9a3bcf5a28ad9de702d75e24fac48..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp +++ /dev/null @@ -1,662 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP -#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_contraction_dlops_v1r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, - const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, - const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) -{ - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // GM0 and GN0 need to known at compile-time - static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); - static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); - static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // lds max alignment - // TODO: part of them should be moved into blockwise-gemm - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = GK1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! GM0 and GN0 need to be known at compile-time"); - - const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); - const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); - const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return ( - (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && - GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && - GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && - GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && - GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && - GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && - GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && - GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && - GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && - GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && - (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr index_t GM11 = GM1PerBlockGM11; - constexpr index_t GN11 = GN1PerBlockGN11; - - const index_t GM10 = GM1 / GM11; - const index_t GN10 = GN1 / GN11; - - const index_t grid_size = GM10 * GN10; - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) - { - const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) - { - const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) - { - const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); - - const auto GM11 = Number{}; - const auto GM10 = GM1 / GM11; - - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor( - a_grid_desc_gk0_gm0_gm1_gk1, - make_tuple(make_pass_through_transform(GK0), - make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11)), - make_pass_through_transform(GK1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return a_grid_desc_gk0_gm0_gm10_gm11_gk1; - } - - __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) - { - const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); - const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); - - const auto GN11 = Number{}; - const auto GN10 = GN1 / GN11; - - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor( - b_grid_desc_gk0_gn0_gn1_gk1, - make_tuple(make_pass_through_transform(GK0), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11)), - make_pass_through_transform(GK1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return b_grid_desc_gk0_gn0_gn10_gn11_gk1; - } - - __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - constexpr auto BM = GM0 * GM11; - constexpr auto BN = GN0 * GN11; - - constexpr auto BM1 = - Number{}; - constexpr auto BN1 = - Number{}; - - constexpr auto BM0 = BM / BM1; - constexpr auto BN0 = BN / BN1; - - const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor( - c_grid_desc_gm0_gm1_gn0_gn1, - make_tuple(make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11)), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor( - c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_merge_transform(make_tuple(GM0, GM11)), - make_pass_through_transform(GN10), - make_merge_transform(make_tuple(GN0, GN11))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor( - c_gm10_bm_gn10_bn_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_unmerge_transform(make_tuple(BM0, BM1)), - make_pass_through_transform(GN10), - make_unmerge_transform(make_tuple(BN0, BN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; - } - - __host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10( - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(GM10, GN10))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_grid_block_cluster_blockid_to_gm10_gn10; - } - - using AGridDesc_GK0_GM0_GM10_GM11_GK1 = - decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); - using BGridDesc_GK0_GN0_GN10_GN11_GK1 = - decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); - using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = - decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); - using CGridBlockCluster_BlockId_To_GM10_GN10 = - decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, - const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, - const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); - - const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); - - // divide block work by [GM10, GN10] - const auto c_gm10_gn10_block_cluster_idx = - c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); - const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); - - // lds max alignment - // TODO: part of them should be moved into blockwise-gemm - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = GK1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); - - // A matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); - - // B matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); - - static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == - a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == - b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), - "wrong!"); - - // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< - BlockSize, - InMemoryDataOperationEnum::Set, - Sequence, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), - decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths - ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder - false, - true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - make_multi_index(0, 0, igm10, 0, 0), - a_block_desc_gk0_gm0_gm10_gm11_gk1, - make_multi_index(0, 0, 0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< - BlockSize, - InMemoryDataOperationEnum::Set, - Sequence, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), - decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths - BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder - false, - true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - make_multi_index(0, 0, ign10, 0, 0), - b_block_desc_gk0_gn0_gn10_gn11_gk1, - make_multi_index(0, 0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS - // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS - // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - decltype(a_block_desc_gk0_bm_gk1), - decltype(b_block_desc_gk0_bn_gk1), - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - BM1PerThreadBM11, - BN1PerThreadBN11>{}; - - constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = - decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - - constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_thread_desc_bm0_bm1_bn0_bn1, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t gk0_block_on_grid = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); - - gk0_block_on_grid += 2 * GK0PerBlock; - } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = - blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), - decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), - Sequence<1, - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], - 1, - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - make_multi_index(igm10, - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], - ign10, - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} - .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp index 8b82b65540d1c5612002549f72708be07b649dae..d686c14b350953a48bf132694251ac677fb47c4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 05257d16275e5f9817ec5d9cbbd06b133f0b5e05..bf0e8c186c75cef05b08c7ef9e1bf16c016830a3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index b09a73590239609908265cd6b72e3e5bb26127aa..3ea72b85345869a938e52f872c2c586c6e36170f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 16ba23280ddcdf9c3f1d87cb2e4cd64a72691650..533559adf173813a239b522cc5214889ceae07b9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -66,7 +66,8 @@ __global__ void const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 9c68b4f5c3ffbb13cd44701121fca7cc113d76cd..27f48a84ba72fe284a964ac295540eac1c4f6766 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp deleted file mode 100644 index 84e033e1e91bb9873ebed66906c6da5a287bf1bd..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp +++ /dev/null @@ -1,608 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP -#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r2.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AKM0M1GridDesc a_k_m0_m1_grid_desc, - const BKN0N1GridDesc b_k_n0_n1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseGemmDlops_km_kn_mn_v1r2 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K == b_k_n_grid_desc.GetLength(I0)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); - } - - __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) - { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) - { - const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) - { - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - - const auto M1 = Number{}; - const auto M0 = M / M1; - - const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor( - a_k_m_grid_desc, - make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - return a_k_m0_m1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) - { - const auto K = b_k_n_grid_desc.GetLength(I0); - const auto N = b_k_n_grid_desc.GetLength(I1); - - const auto N1 = Number{}; - const auto N0 = N / N1; - - const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor( - b_k_n_grid_desc, - make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - return b_k_n0_n1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - constexpr auto M11 = - Number{}; - constexpr auto N11 = - Number{}; - - constexpr auto M10 = M1 / M11; - constexpr auto N10 = N1 / N11; - - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), - make_unmerge_transform(make_tuple(N0, N10, N11))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return c_m0_m10_m11_n0_n10_n11_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); - using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AKM0M1GridDesc& a_k_m0_m1_grid_desc, - const BKN0N1GridDesc& b_k_n0_n1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); - - const auto K = a_k_m0_m1_grid_desc.GetLength(I0); - - // divide block work by [M, N] - const auto c_m0_n0_block_cluster_idx = - cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); - const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M0_M1, - ABlockTransferThreadClusterLengths_K_M0_M1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k_m0_m1_grid_desc), - decltype(a_k_m0_m1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_M1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k_m0_m1_grid_desc, - make_multi_index(0, im0, 0), - a_k_m0_m1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N0_N1, - BBlockTransferThreadClusterLengths_K_N0_N1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k_n0_n1_grid_desc), - decltype(b_k_n0_n1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_N1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k_n0_n1_grid_desc, - make_multi_index(0, in0, 0), - b_k_n0_n1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; - constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); - - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; - constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = - AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = - BGridMoveSliceWindowStepHacks{}; - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_k_n0_n1_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = - blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), - Sequence<1, - c_m10_m11_n10_n11_thread_tensor_lengths[I0], - c_m10_m11_n10_n11_thread_tensor_lengths[I1], - 1, - c_m10_m11_n10_n11_thread_tensor_lengths[I2], - c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, - make_multi_index(im0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], - in0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp deleted file mode 100644 index b1dfb0c73fc92a19662e01df31090008cdef87ff..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp +++ /dev/null @@ -1,461 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_V2_HPP -#define CK_GRIDWISE_GEMM_V2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "blockwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -struct GridwiseGemmDlops_km_kn_mn_v3 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto E = EPerBlock * 3 * 3; - - constexpr auto max_lds_align = - math::lcm(Number{}, Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); - - return a_block_space_size * sizeof(FloatAB); - } - - template - __device__ void Run(const AGlobalDesc& a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc& b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc& c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - FloatAB* __restrict__ p_shared_block, - integral_constant, - integral_constant) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e_k_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); - - constexpr auto E = EPerBlock * 3 * 3; - - // const auto E = a_e_k_global_desc.GetLength(I0); - const auto K = a_e_k_global_desc.GetLength(I1); - - const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); - const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); - const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); - -// divide block work by [M, N] -#if 0 - const auto ho_block_work_num = Ho / Number{}; - const auto wo_block_work_num = Wo / Number{}; - const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; - - const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; - const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; - - const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; - const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; -#else - // Hack: this force result into SGPR - const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); - const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); - const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; - - const index_t k_block_work_id = - __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); - const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; - - const index_t ho_block_work_id = - __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); - const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; -#endif - - // lds max alignment - constexpr auto max_lds_align = - math::lcm(Number{}, Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - - auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const auto k_thread_id = c_thread_mtx_index.k; - const auto ho_thread_id = c_thread_mtx_index.h; - const auto wo_thread_id = c_thread_mtx_index.w; - - const index_t k_block_data_on_global = k_block_work_id * KPerBlock; - const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; - const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; - - const index_t ho_thread_data_on_global = - ho_block_data_on_global + ho_thread_id * HoPerThread; - const index_t wo_thread_data_on_global = - wo_block_data_on_global + wo_thread_id * WoPerThread; - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_e_k_global_desc), - decltype(a_e_k_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1>, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_e_k_global_desc, - make_multi_index(0, k_block_data_on_global), - a_e_k_desc, - make_multi_index(0, 0)); - - constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - auto b_threadwise_transfer = - ThreadwiseTensorSliceTransfer_v2, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - 1, - true>( - b_e_n_ho_wo_global_desc, - make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); - - auto a_block_buf = make_dynamic_buffer( - p_shared_block, a_e_k_desc.GetElementSpaceSize()); - - // register allocation for output - StaticBuffer - c_thread_buf; - - // initialize output thread tensor - ThreadwiseTensorSliceSet_v1>{} - .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - - constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; - constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; - constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = - BGlobalMoveSliceWindowStepHacks{}; - - // double regsiter buffer for b - StaticBuffer - b_thread_even_buf, b_thread_odd_buf; - - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_even_buf, - b_e_n_ho_wo_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainKBlockLoop) - { - index_t e_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_odd_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on current data - // TODO: @Zhang Jing: blockwise gemm should be able to move slice window - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_even_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - e_block_data_begin += 2 * EPerBlock; - - } while(e_block_data_begin < E - 2 * EPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_odd_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - - // output: register to global memory - { - // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor - constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; - - const index_t k_thread_data_on_global = - k_block_data_on_global + k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>( - c_k_n_ho_wo_global_desc, - make_multi_index( - k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) - .Run(c_k_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - c_k_n_ho_wo_global_desc, - c_global_buf, - c_k_n_ho_wo_global_tensor_step_hacks); - } - } - - // pass tensor descriptor by reference - template - __device__ void Run(const AGlobalDesc& a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc& b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc& c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - p_shared_block, - integral_constant{}, - integral_constant{}); - } - - // pass tensor descriptors by their pointers - template - __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc* p_b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc* p_c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - const auto a_e_k_global_desc = *p_a_e_k_global_desc; - const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; - const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - integral_constant{}, - integral_constant{}); - } - - // pass tensor descriptors by void* - template - __device__ void Run(const void* p_a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const void* p_b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const void* p_c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); - const auto b_e_n_ho_wo_global_desc = - *reinterpret_cast(p_b_e_n_ho_wo_global_desc); - const auto c_k_n_ho_wo_global_desc = - *reinterpret_cast(p_c_k_n_ho_wo_global_desc); - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - integral_constant{}, - integral_constant{}); - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp deleted file mode 100644 index ace844338414668030e5dd6aeeaf559e1a7c1060..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp +++ /dev/null @@ -1,1597 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_V3_HPP -#define CK_GRIDWISE_GEMM_V3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" -#include "blockwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActiv(p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3_resize_add( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_d_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid, - p_b_grid, - p_bias_grid, - p_d_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3_maxpool( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - FloatC* __restrict__ p_d_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActivMaxpool(p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_d_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseGemmDlops_km_kn_mn_v3 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - - static constexpr auto E1 = Number{}; - static constexpr auto E2 = Number{}; - static constexpr auto K2 = Number{}; - - static constexpr auto NPerBlock = I1; - - static constexpr FloatAcc alpha = 0.3; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(I1, Number{}, Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = math::integer_least_multiple( - a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); - - return a_block_space_size * sizeof(FloatAB); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - - const auto K0 = K / KPerBlock; - const auto N0 = N / NPerBlock; - const auto H0 = Ho / HoPerBlock; - const auto W0 = Wo / WoPerBlock; - - const index_t grid_size = K0 * N0 * H0 * W0; - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0) - { - const bool has_main_e0_block_loop = E0 > 1; - - return has_main_e0_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop() - { - const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1; - - return has_main_e1_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop() - { - const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0; - - return has_double_tail_e1_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc) - { - const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0); - const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor( - a_e0_e1_k_e2_grid_desc, - make_tuple(make_pass_through_transform(E0), - make_pass_through_transform(E1), - make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return a_e0_e1_k0_k1_e2_grid_desc; - } - - __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor( - const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) - { - const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); - // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1); - const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2); - const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); - const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); - // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); - - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Ho / (H1 * H2); - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Wo / (W1 * W2); - - const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = - transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, - make_tuple(make_pass_through_transform(E0), - make_pass_through_transform(E1), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3, 4, 5>{}, - Sequence<6, 7, 8>{}, - Sequence<9>{})); - - return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Ho / (H1 * H2); - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Wo / (W1 * W2); - - const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor( - c_k_n_ho_wo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) - { - const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); - const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); - const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); - const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Number{}; - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Number{}; -#else - const auto H2 = HoPerThread / 2; - const auto H1 = HoPerBlock / HoPerThread; - const auto H0 = Hx / (H1 * H2); - - const auto W2 = WoPerThread / 2; - const auto W1 = WoPerBlock / WoPerThread; - const auto W0 = Wx / (W1 * W2); -#endif - - const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( - d_k_n_hx_wx_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) - { - const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); - const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); - const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); - const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto H2 = Number{}; - const auto H1 = Number{}; - - const auto W2 = Number{}; - const auto W1 = Number{}; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto H0 = Number{}; - const auto W0 = Number{}; -#else - const auto H0 = Hx / (H1 * H2); - const auto W0 = Wx / (W1 * W2); -#endif - - const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( - d_k_n_hx_wx_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto K0 = Number{}; - const auto N0 = Number{}; - const auto H0 = Number{}; - const auto W0 = Number{}; -#else - const auto K0 = K / KPerBlock; - const auto N0 = N / NPerBlock; - const auto H0 = Ho / HoPerBlock; - const auto W0 = Wo / WoPerBlock; -#endif - - const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_k_n_ho_wo_block_cluster_adaptor; - } - - // using AGridDesc_E0_E1_K0_K1_E2 = - // decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); - // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = - // decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); - // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = - // decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); - // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = - // decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{})); - - using CBlockIdToBlockClusterAdaptor_K_N_H_W = - decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); - - template - __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) - { - const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0); - const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1); - - return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); - } - - __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor() - { - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); - return c_k1_n_h2_w2_thread_gemm_desc; - } - - // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor()); - - __host__ __device__ static constexpr auto GetBlockWiseGemm() - { - constexpr auto max_lds_align = Number{}; - - constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - constexpr auto b_e1_n_h_w_e2_block_gemm_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - I1, - Number{}, - Number{}, - Number{})); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - - return blockwise_gemm; - } - - __device__ static constexpr auto GetCThreadIndex() - { - auto blockwise_gemm = GetBlockWiseGemm(); - auto c_thread_mtx_index = - blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); - - return c_thread_mtx_index; - }; - - __device__ static constexpr auto GetCBlockIndex( - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor) - { - const auto c_k_n_h_w_block_cluster_idx = - cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - return c_k_n_h_w_block_cluster_idx; - } - - template - __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf, - CThreadBuff& c_thread_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc, - const CThreadDesc_K1_N_H2_W2&) - - { - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - - const auto k_thread_id = c_thread_idx[I0]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - constexpr auto bias_k0_k1_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); - - StaticBuffer - bias_thread_buf; - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - auto bias_threadwise_transfer = - ThreadwiseTensorSliceTransfer_v2{}>, - Sequence<0, 1>, - 1, - CThreadTransferDstScalarPerVector, - false, - true>( - bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global)); - - constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple( - make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{})); - - bias_threadwise_transfer.Run(bias_k0_k1_grid_desc, - bias_global_buf, - bias_k0_k1_thread_desc, - make_tuple(I0, I0), - bias_thread_buf, - bias_k0_k1_global_tensor_step_hacks); - - static_for<0, KPerThread, 1>{}([&](auto ki) { - static_for<0, HoPerThread, 1>{}([&](auto hi) { - static_for<0, WoPerThread, 1>{}([&](auto wi) { - constexpr index_t c_offset = - c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi)); - c_thread_buf(Number{}) = - c_thread_buf[Number{}] + bias_thread_buf[ki]; - }); - }); - }); - } - - template - __device__ static void Activation(CThreadBuff& c_thread_buf, - const CThreadDesc_K1_N_H2_W2&, - integral_constant) - { - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { - if constexpr(activ_type_ == 1) - { - c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; - } - else if constexpr(activ_type_ == 2) - { - FloatAcc x = 1.0 + exp(-c_thread_buf[i]); - - asm volatile("\n \ - v_rcp_f32 %0, %1 \n" - : "=v"(x) - : "0"(x)); - - c_thread_buf(i) = x; - } - }); - } - - template - __device__ static void - WriteOut(const CThreadBuff& c_thread_buf, - CGlobalBuff& c_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) - { - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global - // tensor - constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; - - constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - c_global_buf, - c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks); - } - - template - __device__ static void - MaxPool(const CThreadBuff& c_thread_buf, - DGlobalBuff& d_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CThreadDesc_K1_N_H2_W2&, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) - { - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, ""); - - constexpr auto HoPerThread_2 = HoPerThread / 2; - constexpr auto WoPerThread_2 = WoPerThread / 2; - - constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - StaticBuffer - d_thread_buf; - - static_for<0, KPerThread, 1>{}([&](auto ki) { - static_for<0, HoPerThread_2, 1>{}([&](auto hi) { - static_for<0, WoPerThread_2, 1>{}([&](auto wi) { - constexpr index_t d_offset = - d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset( - make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi)); - - constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2, wi * 2)); - constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2, wi * 2 + 1)); - constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2 + 1, wi * 2)); - constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); - - d_thread_buf(Number{}) = c_thread_buf[Number{}]; - d_thread_buf(Number{}) = - fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); - d_thread_buf(Number{}) = - fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); - d_thread_buf(Number{}) = - fmax(c_thread_buf[Number{}], d_thread_buf(Number{})); - }); - }); - }); - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - InMemoryDataOperationEnum::Set, - 1, - true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - d_thread_buf, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - d_global_buf, - d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); - } - - template - __device__ static void - ResizeAdd(const CThreadBuff& c_thread_buf, - DGlobalBuff& d_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CThreadDesc_K1_N_H2_W2&, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) - { - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - constexpr auto HoPerThreadx2 = HoPerThread * 2; - constexpr auto WoPerThreadx2 = WoPerThread * 2; - - constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - StaticBuffer - d_thread_buf; - - static_for<0, KPerThread, 1>{}([&](auto k_i) { - static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { - static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { - d_thread_buf(Number{}) = - c_thread_buf[Number{}]; - }); - }); - }); - - // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor - constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - InMemoryDataOperationEnum::Add, - 1, - true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - d_thread_buf, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - d_global_buf, - d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); - } - - template - __device__ static void - GemmOp(const AGlobalBuff& a_global_buf, - const BGlobalBuff& b_global_buf, - CThreadBuff& c_thread_buf, - FloatAB* __restrict__ p_shared_block, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CThreadDesc_K1_N_H2_W2&, - integral_constant) - { - constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); - constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); - - // const auto c_k_n_h_w_block_cluster_idx = - // GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - // cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( - // make_multi_index(get_block_1d_id())); - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - constexpr auto max_lds_align = Number{}; - - constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - constexpr auto b_e1_n_h_w_e2_block_gemm_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - I1, - Number{}, - Number{}, - Number{})); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); - - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, I1, Number{}, Number{}), - max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_e0_e1_k0_k1_e2_grid_desc), - decltype(a_e0_e1_k0_k1_e2_block_copy_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - ABlockTransferSrcVectorDim, - 4, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_E2, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - false>(a_e0_e1_k0_k1_e2_grid_desc, - make_multi_index(0, 0, k_block_work_id, 0, 0), - a_e0_e1_k0_k1_e2_block_copy_desc, - make_multi_index(0, 0, 0, 0, 0)); - - constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); - - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{}, - Number{})); - - auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< - FloatAB, - FloatAB, - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), - Sequence, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - make_multi_index(0, - 0, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0, - 0)); - - auto a_block_buf = make_dynamic_buffer( - p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); - - //// register allocation for output - // StaticBuffer - // c_thread_buf; - - // initialize output thread tensor - ThreadwiseTensorSliceSet_v1>{} - .Run(c_k1_n_h2_w2_thread_gemm_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto b_thread_slice_copy_step = - make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; - - // double regsiter buffer for b - StaticBuffer - b_thread_even_buf, b_thread_odd_buf; - - if constexpr(HasMainE0BlockLoop) - { - const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0); - - index_t e0_block_data_begin = 0; - - do - { - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead( - a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainE1BlockLoop) - { - index_t e1_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - e1_block_data_begin += 2 * E1PerBlock; - - } while(e1_block_data_begin < E1 - 2 * E1PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - - a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, - a_block_slice_copy_step, - AGlobalMoveSliceWindowStepHacks{}); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - e0_block_data_begin += 1; - - } while(e0_block_data_begin < E0); - } - else - { - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead( - a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainE1BlockLoop) - { - index_t e1_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - e1_block_data_begin += 2 * E1PerBlock; - - } while(e1_block_data_begin < E1 - 2 * E1PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - } - } - - template - __device__ static void - Conv(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant) - { - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - } - - template - __device__ static void ConvBiasActiv( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - } - - template - __device__ static void ConvBiasActivMaxpool( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - // MaxPool - MaxPool(c_thread_buf, - d_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k1_n_h2_w2_thread_gemm_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); - } - - template - __device__ static void ConvBiasActivResizeAdd( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Resize_Add - ResizeAdd(c_thread_buf, - d_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k1_n_h2_w2_thread_gemm_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 578665ea85fb723565e597bbec26457ffb449606..7289a20da03523b0496353a0cea48e0568c50de0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 1e8f8ff9fe2db6442f6d39ea3839f4eab8dbf429..119c1ea596dc6c7f3cff9b20a94c9f2dbe85e4d4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + constexpr auto max_lds_align = K1; constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_space_size_aligned = math::integer_least_multiple( b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)); + constexpr auto c_block_space_size_aligned = math::integer_least_multiple( + cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), + max_lds_align); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_space_size_aligned * sizeof(CShuffleDataType)); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -497,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index da0b0cea241d7b62dd8e138aa5dd7cb0ccdaa0cd..ec1cc53991d128a8f281e01289a6a8625810971a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // ABDataTypeAdjusted -> ABDataType throughout this file +#if CK_WORKAROUND_DENORM_FIX + using ABDataTypeAdjusted = + conditional_t, ck::bhalf_t, ABDataType>; +#else + using ABDataTypeAdjusted = ABDataType; +#endif + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABDataType, - ABDataType, + ABDataTypeAdjusted, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, ABDataType, - ABDataType, + ABDataTypeAdjusted, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataType, + ABDataTypeAdjusted, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 98331d854490abe4e7f478d817104218cb49c28f..9090c16dcdd08ea39d1a2a591d8c605b20a55248 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index e9097552c2baf236b68370d6d98a771a2eec4e05..d1209636de93a13cbe7f851575b4abaedd99af72 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index 3281b910d3d8782fc18f91ee77ce42aa44fc5ba9..f54345b04f2c8da02365d3e123a660b04f51ad58 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 2fe55068449377a82b0c4bd12c65f05ea4e96199..3d1fbb73bd0087c682f9a72a2b6fb3d1f411cf62 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,7 +54,8 @@ __global__ void const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index aa89bff9ee21873fdc57c839eb800be455005a9a..e7577bdcbd52b87cf986e0c9f92cc494089b9423 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp index 2d3a36fca08d7811e406a89b398f401b741a7649..de5a4241986ff64e1797f2304da117178059ea4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 1fee302c3c78470d2c9f0b5125bde454a7b5abeb..740dde5c65f3d02282d6a0e5b05d41c6591de72f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB)) + { + return false; + } return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index ecc528a7ed79a337efd02cc5efe9dd11039b3b29..7f1fb21d33efe67f4bebc8660249b755b9e51914 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,17 +17,25 @@ namespace ck { -template +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -35,54 +43,33 @@ __global__ void kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + typename GridwiseGemm::Problem problem) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_ctile_map; + ignore = problem; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -template {}; // K1 should be Number<...> - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + }; + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm using GridwiseGemmPipe = remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); } - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // B matrix in LDS memory, dst of blockwise copy return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); } - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); @@ -171,14 +519,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); @@ -199,36 +547,102 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ static constexpr bool CheckValidity(const Problem& problem) { static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); - const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); - const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) - return false; + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - return false; + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } - // check gridwise gemm pipeline - const auto num_k_loop = K / KPerBlock; + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } - if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + if constexpr(is_same::value) { - return false; + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -237,22 +651,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return true; } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), @@ -264,33 +673,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) - { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); - } - - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - - using DefaultBlock2CTileMap = - remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap& block_2_ctile_map) + const Problem& problem) { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -298,7 +700,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N}; + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -318,7 +726,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -332,7 +740,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, FloatAB, @@ -363,7 +771,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, FloatAB, @@ -395,8 +803,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -424,8 +833,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); // gridwise GEMM pipeline static_assert(std::is_default_constructible_v); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 94e181cd45470c08f4f7b35b84ff991a56babb82..8675350094b385824d39d75d19ce74ec41ef2778 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,8 @@ __global__ void const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index acece0fbba42bc322cb3102aaf8898a77b41b9d0..31c59d14eddb6983c43ff97b9999f78aa7f980b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 126887cbacd727a7c5a3d2026599e0a9c4d9e2e3..0fc18bb923f10e2bd055b0df4ad89be332766620 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -165,16 +165,14 @@ __global__ void const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, - p_shared_block, + p_shared, a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, @@ -183,16 +181,16 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -264,6 +262,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using GridwiseGemmPipe = remove_cvref_t())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatABAdjusted -> FloatAB throughout this file +#if CK_WORKAROUND_DENORM_FIX + using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; +#else + using FloatABAdjusted = FloatAB; +#endif + // M0/M1/M1Padding static constexpr auto M1PerBlock = Number{}; static constexpr auto M0PerBlock = Number{}; @@ -605,7 +613,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, + void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& @@ -666,7 +674,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -696,7 +704,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -725,11 +733,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); @@ -798,8 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - void* p_shared = static_cast(p_shared_block); - auto c_block_buf = make_dynamic_buffer( static_cast(p_shared), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 2aad7128f0681bcd903c2c11249380e3883c3e53..b12bcee0f414c6e56b6550c084630b86ab153bb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -44,7 +44,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index d1149c0c2e3c80305e6cb415568cb1fa1fbb7a39..7b8bbd301eefc3b29d4f62e1ff5ccde8e066aeb4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" @@ -21,29 +22,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -52,22 +45,46 @@ __global__ void p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m_n; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = block_2_ctile_map; + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -76,9 +93,6 @@ template ; + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + template + __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1); + } + + template + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + K0{CalculateK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "K0:" << K0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t K0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + }; + using GridwiseGemmPipe = remove_cvref_t())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatABAdjusted -> FloatAB throughout this file +#if CK_WORKAROUND_DENORM_FIX + using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; +#else + using FloatABAdjusted = FloatAB; +#endif + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -193,13 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + const CGridDesc_M_N& c_grid_desc_m_n) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -228,7 +349,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return false; } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + // check gridwise gemm pipeline + const index_t K0 = problem.K / K1Value; + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -237,15 +375,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return true; } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + template __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n) { constexpr auto max_lds_align = K1; @@ -281,7 +420,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + template + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); - } + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -327,7 +458,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; // divide block work by [M, N] const auto block_work_idx = @@ -367,7 +503,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(a_grid_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1), ABlockTransferSrcAccessOrder, @@ -398,7 +534,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(b_grid_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1), BBlockTransferSrcAccessOrder, @@ -428,7 +564,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // sanity check auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatAB, + FloatABAdjusted, FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), @@ -446,16 +582,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0_n_k1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, @@ -554,4 +691,309 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } }; +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext + : GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + using Parent = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + using typename Parent::GridwiseGemmPipe; + using typename Parent::Problem; + + using Parent::I1; + + using Parent::K1; + + __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const index_t K0 = problem.K / K1; + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 949d56483669946f12f40adca568d4e2a6b0bb5b..19fbee727f77ae8b6b6f90a5a6f520accde95332 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -42,7 +42,8 @@ __global__ void const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 190194f1eb18945eb0e6dc3334e9e6a5a5248267..4cd70cc90288342298937845cca09573a40f5881 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,70 +8,39 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename Block2CTileMap> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - static_cast(p_shared_block), - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = karg; + ignore = b2c_map; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -79,13 +48,14 @@ template + typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -126,10 +98,229 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; using ThisThreadBlock = ThisThreadBlock; + using GridwiseGemmPipe = remove_cvref_t())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0; + index_t k_batch; + + Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0(K0_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0:" << K0 << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0 = CalculateK0(K, K_Batch); + return K_Batch * K0 * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -178,54 +369,161 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 c_block_size * sizeof(FloatC)); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); - const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); - - if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && - K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && - K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && - KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) - return false; + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; +#endif // DEBUG_LOG + return false; + } + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } - if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + if constexpr(is_same::value) { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg N (" << karg.N + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg M (" << karg.M + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + const auto num_k_loop = karg.K0 / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { +#if DEBUG_LOG + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." + << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; +#endif // DEBUG_LOG return false; } - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) { - const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } - return has_main_k0_block_loop; + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const index_t num_loop = K0 / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + template __host__ __device__ static constexpr auto - MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) { const auto M = c_m_n_grid_desc.GetLength(I0); const auto N = c_m_n_grid_desc.GetLength(I1); @@ -242,10 +540,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } // return block_id to C matrix tile idx (m0, n0) mapping + template __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( c_m_n_grid_desc, 8, KBatch); } @@ -262,24 +561,37 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Number{})); } - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, + template + __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block, - const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const CBlockClusterAdaptor& c_block_cluster_adaptor) + const Block2CTileMap& block_2_ctile_map) { + const FloatAB* p_a_grid = karg.p_a_grid; + const FloatAB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const AElementwiseOperation a_element_op = AElementwiseOperation{}; + const BElementwiseOperation b_element_op = BElementwiseOperation{}; + const CElementwiseOperation c_element_op = CElementwiseOperation{}; + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -287,28 +599,28 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - - // divide block work by [M, N] + // divide block work by [KBatch, M, N] const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - if(!c_block_cluster_adaptor.ValidCTileIndex( - make_tuple(block_work_idx[I1], block_work_idx[I2]), + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } - const index_t k_batch_id = block_work_idx[I0]; + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); // lds max alignment constexpr auto max_lds_align = K1; @@ -445,17 +757,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // register // sanity check - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -474,6 +787,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); +#if 0 // preload data into LDS { a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); @@ -510,7 +824,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); + } while(k0_block_data_begin < (karg.K0 - K0PerBlock)); } // tail @@ -519,6 +833,30 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } +#else + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); +#endif // output: register to global memory { @@ -647,7 +985,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 {c_block_desc_mblock_mperblock_nblock_nperblock, make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op}; constexpr auto mxdlperwave_forward_step = @@ -716,6 +1054,30 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 }); } } + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlSplitKCShuffle_" + << getGemmSpecializationString(GemmSpec) << "_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index ffb2926c877e1569847da43fe08c05e30073dc84..d090ba54d88095cb67fa97ad58146a74cef3e346 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -46,7 +46,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 7e6dbb3b2ed85f55ec062a7b60cedc7b40337d54..2c09e80fdc7e2987dbb53be557187d05e2a961cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -49,7 +49,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index fb1e34b98543665717624af2201d9cbfe4596cda..c1bc6a8fec37f3ac9d94a371d9426ad108622d3b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,7 +53,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp index de1ae915920503a086d2cf939ae6e41d27a28838..61d0f9e0d5573c774bbb1dcba707e8c2e775f075 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8f72f88b068afc4a15e94b16dc659a2b2b9f992a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation elementwise_op) +{ + GridwisePutElementwise1dFunctor::Run( + in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op); +} + +// output[indices] = input +template +struct GridwisePutElement_1D +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation& elementwise_op) + { + // Global Memory + const auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_1d_desc.GetElementSpaceSize()); + + const auto indices_global_buf = + make_dynamic_buffer(p_indices_global, + in_grid_1d_desc.GetElementSpaceSize(), + NumericLimits::Lowest()); + + // VGPR + StaticBuffer in_thread_buf; + StaticBuffer indices_thread_buf; + + // Thread id, Block id and index + const index_t thread_global_id = get_thread_global_1d_id(); + const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize); + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * InVectorSize; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + auto indices_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + index_t num_iter = M / loop_step; + do + { + in_global_load.Run(in_grid_1d_desc, + in_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf); + + in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}( + [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); }); + + indices_global_load.Run(in_grid_1d_desc, + indices_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + indices_thread_buf); + + indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}([&](auto iM) { + if(indices_thread_buf[iM] >= 0) + { + if constexpr(MemOp == InMemoryDataOperationEnum::Set) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) = + ck::type_convert(in_thread_buf[iM]); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd) + { + atomic_add(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax) + { + atomic_max(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::Add) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) += + ck::type_convert(in_thread_buf[iM]); + } + else + { + static_assert(MemOp == InMemoryDataOperationEnum::Set || + MemOp == InMemoryDataOperationEnum::AtomicAdd || + MemOp == InMemoryDataOperationEnum::AtomicMax || + MemOp == InMemoryDataOperationEnum::Add); + } + } + }); + + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 901e7aee98a7b6ab1c37b7dd4dc36f796632fef6..41352fabeb361b017949738bda31649a3ce0dd40 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp index 88c7b6acfeb9c43ec0e8b5a5ab557f834430fb40..0ad36b418a42bff56eb87ba35bb5150808e9cac4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 0344e68305b1a951cfc41a53666090bd1e0113b6..5f56ac6fc4664be1685a950dc6e242b3e256e178 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index ff2511fa6e61eda6c4cca976d0372cf98840ee5c..ee68660a06a2933e45ead9c05fadc674ec9d5ae7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 792ffabcb90c71b4620915572c29091b67c8c4d6..c3f122106df29a471e376c453d055806c9aa2514 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index 37795fa5694ae402af3225c11ffdbe16a6bc086f..e50fb9813325b08576f49d321e7a5de5d51bbff2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp" namespace ck { template +struct GridwiseNormalizationSplitK1st +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static int + GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) + { + bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; + + if(is_rightmost_block) + { + int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); + int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + else + { + int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); + return KThreadSliceSize * (kPerBlock / K_BlockTileSize); + } + } + + // Calculate mean and variance by welford along k dimension + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + MeanVarDataType* const p_mean_global, + MeanVarDataType* const p_variance_global, + int32_t* const p_welford_count_global) + { + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize)); + + auto mean_var_count_store_index = make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + threadwise_welford.max_count_ = + GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + int welford_count = 0; + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + + // The value of count is same for all I + if constexpr(I == MThreadSliceSize - 1) + welford_count = count; + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_grid_desc_m_kblock, + mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_grid_desc_m_kblock, + var_global_val_buf); + + if(block_m_cluster_id == 0 && thread_m_cluster_id == 0) + p_welford_count_global[block_k_cluster_id] = welford_count; + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..136ac94e7f0fabb81cfd5baea7e58174799f4d43 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseNormalizationSplitK2nd +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(XSrcVectorSize == YDstVectorSize); + static_assert(XSrcVectorSize == GammaSrcVectorSize); + static_assert(XSrcVectorSize == BetaSrcVectorSize); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1); + using ThreadWelfordDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock& count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const MeanVarDataType* const p_mean_global, + const MeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const YElementwiseOperation y_elementwise_op) + { + // Thread/Block id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // Global Memory + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + StaticBuffer + in_mean_thread_buf; + StaticBuffer + in_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; + + // IO + auto threadwise_mean_var_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_count_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + count_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + // step1: Merge mean and variance + constexpr auto mean_var_count_thread_copy_step_I0_k = + make_multi_index(I0, KThreadClusterSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k) + { + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_mean_thread_buf); + + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_var_thread_buf); + + threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_mean_thread_buf, + in_var_thread_buf, + in_welford_count_thread_buf, + mean_thread_buf, + var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow( + mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k); + threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock, + mean_var_count_thread_copy_step_I0_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // step2: normalization + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } // end for (normalization) + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index 3a7ae459e5f6e8d58edc49ae53c4309030e911f9..ff9712276c26d14c9a68f9df618ed4666f2dc5e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index 188c62d93b0d60317ac139115b79f496612d2ea6..c6eecc067d9f77fa488e7a922fe1fb711c0c0d0f 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 94cdfe01087bc4005d45f11992a6a8f94d42691c..44730d551c5c8c9fb2408502ece0d45c6f37a221 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index e045e3b545a9897f5df9df1948123e86424d7c7b..e97aa433a6f3b9d54fe76ed26b5b0f6bd439ae84 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP #define CK_THREADWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp index 0a1197a1630c12c53afa9513f10fcfa907ba82d3..6774a35bcb872466d272da0bccd618709d3fb29b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 16484ddcc5a89792c1331a48516e8317b36f2633..605f2569c6622bb589d46da71020ea46ddba69f7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index bb28c194f4b7a2f371d2be00cb00fc4fe3c17d3f..6665d765f81ccd45f36598b104711e77368b6ee5 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" @@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - // apply SrcElementwiseOperation on src_vector_container - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - SrcData src_v; - - src_element_op_(src_v, src_vector_container.template AsType()[i]); - - src_vector_container.template AsType()(i) = src_v; - }); - // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( @@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - // TODO type_convert is not used yet!!!!! using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; @@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number{}); // do data transpose - // TODO type_convert is not used yet!!!!! transpose_vectors{}( src_vector_refs, dst_vector_refs); }); } - else - { - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); - }); - } + + static_ford{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + DstData dst_v; + src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = dst_v; + }); #endif } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp deleted file mode 100644 index 6a73466efa443af294f2062ee5173cf548880c3d..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp +++ /dev/null @@ -1,886 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "static_tensor.hpp" - -namespace ck { - -namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access_for_src_and_dst -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; - -} // namespace detail - -// Assume: -// 1. src_desc and dst_desc are not known at compile-time -// 2. SrcBuffer and DstBuffer are DynamicBuffer -// 3. src_slice_origin and dst_slice_origin are not known at compile-time, -// 4. Use thread buffer -template // control whether to move back dst coordinate after each - // RunWrite(), will be fused with MoveDstSliceWindow to - // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r3 -{ - static constexpr index_t nDim = SliceLengths::Size(); - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); - using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); - using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( - const SrcDesc& src_desc, - const Index& src_slice_origin, - const SrcElementwiseOperation& src_element_op, - const DstDesc& dst_desc, - const Dst0Desc& dst0_desc, - const Dst1Desc& dst1_desc, - const Index& dst_slice_origin, - const DstElementwiseOperation& dst_element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), - dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), - dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), - src_element_op_(src_element_op), - dst_element_op_(dst_element_op) - { - } - - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) - { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); - } - - __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, - const Dst0Desc& dst0_desc, - const Dst1Desc& dst1_desc, - const Index& dst_slice_origin_idx) - { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); - dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); - dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); - } - - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - constexpr auto src_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - using src_vector_type = vector_type_maker_t; - using src_vector_t = typename src_vector_type::type; - - // copy data from src_buf into src_vector_container - auto src_vector_container = src_vector_type{ - src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - - // apply SrcElementwiseOperation on src_vector_container - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - src_vector_container.template AsType()(i) = - src_element_op_(src_vector_container.template AsType()[i]); - }); - - // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_.template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - - __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() - { -#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); - }); -#else - // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ - // TODO make this logic more generic for more sub-dword datatype - if constexpr(SrcVectorDim != DstVectorDim && - is_same>::value && - is_same>::value && - SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) - { - // each transpose does - // DstScalarPerVector # of src vectors in src_thread_scratch_ - // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ - constexpr index_t num_src_vector = Number{}; - constexpr index_t num_dst_vector = Number{}; - - // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose - // TODO: make this logic generic for all scenario - static_assert(SrcVectorDim != DstVectorDim, "wrong"); - - constexpr auto src_scalar_step_in_vector = generate_sequence( - detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = generate_sequence( - detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst{}, - Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - static_ford{}([&](auto access_idx) { - constexpr auto data_idx = access_idx * scalar_per_access; - - constexpr auto data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - // TODO type_convert is not used yet!!!!! - using src_vector_t = vector_type_maker_t; - using dst_vector_t = vector_type_maker_t; - - // get DstScalarPerVector # of read-only references to src vectors from - // src_thread_scratch_ - const auto src_vector_refs = generate_tie( - [&](auto i) -> const src_vector_t& { - // i increment corresponds to movement in DstVectorDim - return src_thread_scratch_.GetVectorTypeReference( - data_idx_seq + i * dst_scalar_step_in_vector); - }, - Number{}); - - // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ - auto dst_vector_refs = generate_tie( - [&](auto i) -> dst_vector_t& { - // i increment corresponds to movement in SrcVectorDim - return dst_thread_scratch_.GetVectorTypeReference( - data_idx_seq + i * src_scalar_step_in_vector); - }, - Number{}); - - // do data transpose - // TODO type_convert is not used yet!!!!! - transpose_vectors{}( - src_vector_refs, dst_vector_refs); - }); - } - else - { - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); - }); - } -#endif - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - const Dst0Desc& dst0_desc, - const Dst0Buffer& dst0_buf, - const Dst1Desc& dst1_desc, - const Dst1Buffer& dst1_buf) - { - // if there is transpose, it's done here - // TODO move this elsewhere - TransferDataFromSrcThreadScratchToDstThreadScratch(); - - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // src scalar per access on each dim - // TODO: don't use this - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make forward steps: dst0 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst0_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst0_desc, forward_step_idx); - }, - Number{}); - - // make forward steps: dst1 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst1_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst1_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); - - // make backward steps: dst0 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst0_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst0_desc, backward_step_idx); - }, - Number{}); - - // make backward steps: dst1 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst1_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst1_desc, backward_step_idx); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - constexpr auto dst_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - const bool is_dst_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - - using dst_vector_type = vector_type_maker_t; - using dst_vector_t = typename dst_vector_type::type; - - // copy data from dst_thread_scratch_ into dst_vector_container - auto dst_vector_container = dst_vector_type{ - dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; - - // apply DstElementwiseOperation on dst_vector_container - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - dst_vector_container.template AsType()(i) = - dst_element_op_(dst_vector_container.template AsType()[i]); - }); - - // copy data from dst_vector_container to dst_buf - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector_container.template AsType()[I0]); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move dst coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); - } - } - }); - }); - - // move dst coordinate back to slice origin (or not) - if constexpr(DstResetCoordinateAfterRun) - { - const auto dst_reset_step = - make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); - } - } - - __device__ static constexpr auto GetSrcCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - // TODO: BUG: should start at 1 - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, - const Dst0Desc dst0_desc, - const Dst1Desc dst1_desc, - const Index& dst_slice_origin_step_idx) - { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); - move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step); - move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step); - } - - __device__ static constexpr auto GetSrcThreadScratchDescriptor() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - __device__ static constexpr auto GetDstThreadScratchDescriptor() - { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - private: - static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; - static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - - StaticTensorTupleOfVectorBuffer - src_thread_scratch_; - - StaticTensorTupleOfVectorBuffer - dst_thread_scratch_; - - SrcCoord src_coord_; - DstCoord dst_coord_; - const SrcElementwiseOperation src_element_op_; - const DstElementwiseOperation dst_element_op_; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 6e8a23930bbb91677ee18bab216af6c45de72e4c..6a6c1f2ac5bfed7f02c5f5895d16002226cff56c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index f13da341f9b2d36d992b1cd1835b43c89dcabd1f..bd01108b03ca82d4877e19806521efda8da8c6c3 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 9c91cd9ca8f86df37fee6dfdf597be1cefe2d683..6ec9abc41708735011acda7fcd01e54cbec5c4a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index 68bc2726f4be6fa1b08fe9213e0db9440526f975..cf2c7a2aee3d3f612fd12c0868477bfdfe49a2ea 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 0f5fb88b04540dbf355fb0f02998303b7680c193..b5847e51b42ea5b567133f0307b319fa3d8fadbb 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp index 2eb1b0ee90a446ee3b14354aff018275cf7809f4..db7dee21992d0e4edaa2e7a443b5c0ed68589107 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp index 12ba2c5381311ab45ffad4be793651d176f5fc59..eb6715e8ebb382e438791f60193bfc23f06d1773 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 0672bf8e5b2e15421f90bae642800a1b7ffcc0d3..979f3567e987ed66fd45af5deeae7bb2ac3fc761 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -262,12 +262,12 @@ struct wmma_type + class FloatC, + bool neg_a = false, + bool neg_b = false, + bool clamp = false> __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 4d53f0d81620141ff6dfc6cb965209709341aaf8..faaa2c5a9537537a7d2367c87ef9ff66f8028318 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -27,6 +27,8 @@ enum struct MfmaInstr mfma_f32_16x16x8bf16, mfma_i32_32x32x8i8, mfma_i32_16x16x16i8, + mfma_i32_32x32x16i8, + mfma_i32_16x16x32i8, mfma_f64_16x16x4f64 }; @@ -386,6 +388,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x16i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x32i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -524,17 +570,29 @@ struct MfmaSelector #endif } +#if defined(CK_USE_AMD_MFMA_GFX940) + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x16i8; + } + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x32i8; + } +#else template <> static constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x8i8; } - template <> static constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x16i8; } +#endif static constexpr auto selected_mfma = mfma_type()>{}; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp index 5fc11d9158ab9a1261c1285c2cd9f5208d7afbf3..ea27a40ce3c7a48564af112a97af342c157f746d 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 13d0a28cfe5b2db14a83a4b81e0bb379aabbe137..59f64bde77a0b2f3fb1b022a8951b3fecb91442f 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,61 @@ namespace ck { namespace tensor_operation { +namespace { +template < + index_t NDimSpatial, + typename ALayout, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization> +constexpr auto +make_out_n_ho_wo_k_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& out_g_n_k_wos_strides) +{ + + if constexpr(is_same_v) + { + const index_t NStride = out_g_n_k_wos_strides[1]; + const index_t HiStride = out_g_n_k_wos_strides[3]; + const index_t WiStride = out_g_n_k_wos_strides[4]; + const auto CStride = Number<1>{}; + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + } + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); + } +} + +} // namespace + template < index_t NDimSpatial, ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, @@ -29,11 +84,12 @@ struct TransformConvBwdDataToGemm_v1 template , + (is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_AK0_M_AK1( const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, + const std::array& out_g_n_k_wos_strides, const std::array& wei_g_k_c_xs_lengths, const std::array& /* wei_g_k_c_xs_strides */, const std::array& in_g_n_c_wis_lengths, @@ -70,9 +126,9 @@ struct TransformConvBwdDataToGemm_v1 const index_t AK0 = K / AK1; - // assume packed const auto out_n_ho_wo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + make_out_n_ho_wo_k_grid_desc( + N, Ho, Wo, K, out_g_n_k_wos_strides); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -80,7 +136,7 @@ struct TransformConvBwdDataToGemm_v1 { // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + out_n_ho_wo_k_grid_desc, make_tuple(make_pass_through_transform(N * Ho * Wo), make_unmerge_transform(make_tuple(AK0, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 1b5e64b66cf292fbcf4979d95aa4a5b1fa3e3a8b..cee3d2825b13facc912cdc8de21774ce793e9a2c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp index 9f1525914cdee16ee6cbed51297f0008d46f6085..d54f70e750e8032f706ae3f7f0dfec9f7cab4c8d 100644 --- a/include/ck/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index bdfb4f27580e011f0bebabfeae9146374a5ecd67..38ee76d8836d15a85be36e601f574c4f3ec8e539 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -286,7 +286,22 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -template +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct AmdBufferCoherenceEnum +{ + DefaultCoherence = 0, // default value + GLC = 1, + SLC = 2, + GLC_SLC = 3, +}; + +template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) @@ -305,28 +320,37 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w // use fp32 load to mimic fp64 load if constexpr(N == 1) { - const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float2_t tmp = + llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } else if constexpr(N == 2) { - const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float4_t tmp = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } else if constexpr(N == 4) { - const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float4_t f32_0 = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); const float4_t f32_1 = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), - 0); + static_cast(coherence)); vector_type tmp; tmp.AsType()(Number<0>{}) = bit_cast(f32_0); @@ -339,31 +363,40 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); } @@ -372,24 +405,32 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } @@ -398,23 +439,31 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_i16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_i16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } @@ -423,31 +472,40 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int32_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); } } @@ -455,17 +513,23 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i8( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); #else - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -473,11 +537,15 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w else if constexpr(N == 4) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); #else - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -487,19 +555,24 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); #else - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -509,31 +582,36 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + static_cast(coherence)); tmp.AsType()(Number<2>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 8 * sizeof(int8_t), - 0); + static_cast(coherence)); tmp.AsType()(Number<3>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 12 * sizeof(int8_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); #else - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -541,7 +619,9 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } } -template +template __device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, @@ -565,7 +645,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -573,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -584,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -592,7 +672,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -600,7 +680,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -611,7 +691,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -619,7 +699,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -627,7 +707,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 8) { @@ -638,19 +718,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(half_t), - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } } @@ -662,7 +742,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -670,7 +750,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -678,7 +758,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 8) { @@ -688,13 +768,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(bhalf_t), - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -705,7 +785,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -713,7 +793,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -721,7 +801,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -732,7 +812,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -741,13 +821,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } else if constexpr(N == 4) @@ -757,13 +837,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } else if constexpr(N == 8) @@ -772,7 +852,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 16) { @@ -780,7 +860,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } } @@ -1012,7 +1092,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1032,10 +1114,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); @@ -1046,7 +1128,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1064,7 +1148,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(customized_value); @@ -1074,7 +1158,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, @@ -1093,12 +1179,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 1f7df70bcd565dcba4d0aee3a49aca32f835bb45..43baa817d3672c0b2eed6059a88351aeb19a5209 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP diff --git a/include/ck/utility/amd_llvm_intrinsic.hpp b/include/ck/utility/amd_llvm_intrinsic.hpp deleted file mode 100644 index 01e77d7be897085b8c5816fe9e20fe924b859ba7..0000000000000000000000000000000000000000 --- a/include/ck/utility/amd_llvm_intrinsic.hpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_AMD_LLVM_INTRINSIC_HPP -#define CK_AMD_LLVM_INTRINSIC_HPP - -#include "data_type.hpp" - -namespace ck { - -__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); - -} // namespace ck -#endif diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp new file mode 100644 index 0000000000000000000000000000000000000000..741b2975af6c5bf99346b1460018eac6fa33b21b --- /dev/null +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/math.hpp" + +#include +#include +#include +#include + +namespace ck { +namespace detail { + +template +struct get_carrier; + +template <> +struct get_carrier<1> +{ + using type = uint8_t; +}; + +template <> +struct get_carrier<2> +{ + using type = uint16_t; +}; + +template <> +struct get_carrier<3> +{ + using type = class carrier + { + using value_type = uint32_t; + + std::array bytes; + static_assert(sizeof(bytes) <= sizeof(value_type)); + + // replacement of host std::copy_n() + template + __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + { + if(0 < size) + { + *to = *from; + ++to; + for(Size count = 1; count < size; ++count) + { + *to = *++from; + ++to; + } + } + + return to; + } + + // method to trigger template substitution failure + __device__ carrier(const carrier& other) noexcept + { + copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + } + + public: + __device__ carrier& operator=(value_type value) noexcept + { + copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + + return *this; + } + + __device__ operator value_type() const noexcept + { + std::byte result[sizeof(value_type)]; + + copy_n(bytes.begin(), bytes.size(), result); + + return *reinterpret_cast(result); + } + }; +}; +static_assert(sizeof(get_carrier<3>::type) == 3); + +template <> +struct get_carrier<4> +{ + using type = uint32_t; +}; + +template +using get_carrier_t = typename get_carrier::type; + +} // namespace detail + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +template < + typename Object, + typename = std::enable_if_t && std::is_trivially_copyable_v>> +__device__ auto amd_wave_read_first_lane(const Object& obj) +{ + using Size = unsigned; + constexpr Size SgprSize = 4; + constexpr Size ObjectSize = sizeof(Object); + + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) std::byte to_obj[ObjectSize]; + + constexpr Size RemainedSize = ObjectSize % SgprSize; + constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; + for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) + { + using Sgpr = detail::get_carrier_t; + + *reinterpret_cast(to_obj + offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + offset)); + } + + if constexpr(0 < RemainedSize) + { + using Carrier = detail::get_carrier_t; + + *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( + *reinterpret_cast(from_obj + CompleteSgprCopyBoundary)); + } + + /// NOTE: Implicitly start object lifetime. It's better to use std::start_lifetime_at() in this + /// scenario + return *reinterpret_cast(to_obj); +} + +} // namespace ck diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index bf091425485d75a244399f4781e1d44c4b38dd2c..dd7f0b770a11470037b1d33164d830cb9819e8d1 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b4be0cbee70cbe0e6f6827b9c1f74fb9125e2a39..d00b7cd07882ba36c927dba05b7218e45a659b78 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_XDLOPS_HPP #define CK_AMD_XDLOPS_HPP @@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x16i8; + +template <> +struct intrin_mfma_i32_32x32x16i8<32, 32> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x32i8; + +template <> +struct intrin_mfma_i32_16x16x32i8<16, 16> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + template struct intrin_mfma_f64_16x16x4f64; @@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#ifdef __gfx90a__ +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 370a457fe9d97a621e6f91e78f14733398ebf756..f63ce5e5a07a796888cb60ae8da0c855df75e7ff 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP diff --git a/include/ck/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp index 9b8d5b95e9f6241b574104313b5b8fa2984ab18d..c0c1ea65fc77dd625aaaf6456a6628e09df27b51 100644 --- a/include/ck/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp index 6e8b0081587afcb2db159a05ecd5fb940def68ff..610e393a77216500448c3682b5db9b9860a077da 100644 --- a/include/ck/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 1378bbe448e2e1862fed04596fba7f4b106d03b7..f95660a8a47d2c93a65a497a22a74c4c0b6eaea4 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,6 +24,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/utility/magic_division.hpp" #include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/is_known_at_compile_time.hpp" @@ -33,6 +34,7 @@ #include "ck/utility/debug.hpp" #include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/amd_wave_read_first_lane.hpp" #include "ck/utility/generic_memory_space_atomic.hpp" #include "ck/utility/get_id.hpp" #include "ck/utility/thread_group.hpp" diff --git a/include/ck/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp index abc5185e04a5079edb95fe15bfcd6788256d384a..838147e420cc7c95b63273b32ae3b2be2429c017 100644 --- a/include/ck/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index c8b02bc5acaf8542bf028fff39c42061db3b3296..9c7b954565d386a8fdecd21052b102e750ab7102 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 51900361441b6384cda6f8a31d9e76ddb1611439..c240afa2b867951059d59dd3e7e267edd816382e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,6 +12,7 @@ using half_t = _Float16; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4); #endif +using f8_t = uint8_t; // vector_type template @@ -142,6 +143,13 @@ struct scalar_type }; #endif +template <> +struct scalar_type +{ + using type = f8_t; + static constexpr index_t vector_size = 1; +}; + // template struct vector_type @@ -898,6 +906,8 @@ struct vector_type } }; +using int64_t = long; + // fp64 using double2_t = typename vector_type::type; using double4_t = typename vector_type::type; @@ -942,71 +952,13 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// Convert X to Y -template -__host__ __device__ constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - return static_cast(x); -} - -// convert bfp16 to fp32 -template <> -inline __host__ __device__ constexpr float type_convert(bhalf_t x) -{ - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(x) << 16}; - - return u.fp32; -} - -// convert fp32 to bfp16 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - bool flag0 = ~u.int32 & 0x7f800000; - - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bfloat16's mantissa bits are all 0. - bool flag1 = !flag0 && (u.int32 & 0xffff); - - u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even - u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN - - return uint16_t(u.int32 >> 16); -} +// f8 +using f8x2_t = typename vector_type::type; +using f8x4_t = typename vector_type::type; +using f8x8_t = typename vector_type::type; +using f8x16_t = typename vector_type::type; +using f8x32_t = typename vector_type::type; +using f8x64_t = typename vector_type::type; template struct NumericLimits @@ -1054,4 +1006,21 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x77; // 0b01110111 + static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + + __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } +}; + } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 593bbb711672f7dc6e3db22d877bf249b5dff56d..80346f0d9f6f9e5a6a28dcab4e8a7666639394ca 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index c6f0d299ef3c35c63ad97424888545edd2874914..02d61f34ed5192480d09f4432eec30e65e7264b9 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,7 +19,8 @@ namespace ck { template + bool InvalidElementUseNumericalZeroValue, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> struct DynamicBuffer { using type = T; @@ -77,13 +78,16 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_zero, t_per_x>( + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, - t_per_x>( + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -173,7 +177,7 @@ struct DynamicBuffer { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store, t_per_x>( + amd_buffer_store, t_per_x, coherence>( x, p_data_, i, is_valid_element, element_space_size_); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && @@ -376,14 +380,19 @@ struct DynamicBuffer __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; -template +template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { - return DynamicBuffer{p, element_space_size}; + return DynamicBuffer{ + p, element_space_size}; } template < AddressSpaceEnum BufferAddressSpace, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, typename T, typename ElementSpaceSize, typename X, @@ -391,7 +400,7 @@ template < __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { - return DynamicBuffer{ + return DynamicBuffer{ p, element_space_size, invalid_element_value}; } diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 297434b0dddd5f1a680176c3bd099bedfda8dff4..c0a3c99f1fdafea9f151fe9fc319c2f7aaa0ffda 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bb13f98154ef4ec072529c26ea993f8e1a444abf --- /dev/null +++ b/include/ck/utility/f8_utils.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class f8_rounding_mode +{ + standard, + stochastic +}; + +} // namespace ck + +namespace ck::utils { + +namespace { + +template +__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + int exponent; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr uint8_t nan_code = 0x80; + constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + + // convert to bitwise + typedef typename std::conditional::value, uint16_t, uint32_t>::type + T_bitwise; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + if constexpr(is_float) + { + head = x_bitwise & 0xFF800000; + mantissa = x_bitwise & 0x7FFFFF; + exponent = (head >> type_mant) & 0xFF; + sign = head >> (type_exp + type_mant); + } + else if constexpr(is_half) + { + head = x_bitwise & 0xFC00; + mantissa = x_bitwise & 0x3FF; + exponent = (head >> type_mant) & 0x1F; + sign = head >> (type_exp + type_mant); + } + + uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); + uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; + constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + exponent -= exp_low_cutoff - 1; + if(exponent <= 0) + drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; + mantissa += 1 << type_mant; + // apply random number if needed + mantissa += (stoch ? rng : mantissa) & drop_mask; + if(mantissa >= (2 << type_mant)) + { + mantissa >>= 1; + exponent++; + } + mantissa >>= (type_mant - f8_mant); + + // check negative exponent + if(exponent <= 0) + { + if(x_bitwise == 0) + return 0; + else + { + // subnormal range; represented by a subnormal float8 (exponent 0) + // and involves loss of accuracy + mantissa >>= 1 - exponent; + exponent = 0; + } + } + // above range: quantize to maximum possible float of the same sign + else if(exponent > max_exp) + { + if(clip) + { + mantissa = (1 << f8_mant) - 1; + exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); + mantissa &= (1 << f8_mant) - 1; + return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; +} + +template +__host__ __device__ T run_cast_from_f8(f8_t x) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + // prepare the codes + constexpr uint8_t nan_code = 0x80; + T fInf, fNegInf, fNaN, fNeg0; + if constexpr(is_half) + { + constexpr uint16_t ihInf = 0x7C00; + constexpr uint16_t ihNegInf = 0xFC00; + constexpr uint16_t ihNaN = 0x7C01; + constexpr uint16_t ihNeg0 = 0x8000; + fInf = *(reinterpret_cast(&ihInf)); + fNegInf = *(reinterpret_cast(&ihNegInf)); + fNaN = *(reinterpret_cast(&ihNaN)); + fNeg0 = *(reinterpret_cast(&ihNeg0)); + } + else if constexpr(is_float) + { + constexpr uint32_t ifInf = 0x7F800000; + constexpr uint32_t ifNegInf = 0xFF800000; + constexpr uint32_t ifNaN = 0x7F800001; + constexpr uint32_t ifNeg0 = 0x80000000; + fInf = *(reinterpret_cast(&ifInf)); + fNegInf = *(reinterpret_cast(&ifNegInf)); + fNaN = *(reinterpret_cast(&ifNaN)); + fNeg0 = *(reinterpret_cast(&ifNeg0)); + } + + // unpack the input + uint32_t sign = x >> (f8_exp + f8_mant); + uint32_t mantissa = x & ((1 << f8_mant) - 1); + int exponent = (x & 0x7F) >> f8_mant; + + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + typename std::conditional::value, uint16_t, uint32_t>::type retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return fNaN; + } + else + { + if(x == nan_code) + return fNeg0; + if(exponent == ((1 << f8_exp) - 1)) + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); + mantissa <<= sh; + mantissa &= ((1 << f8_mant) - 1); + exponent += 1 - sh; + } + exponent += exp_low_cutoff - 1; + mantissa <<= type_mant - f8_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << type_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +} // namespace + +template +__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted to f8."); + + return run_cast_to_f8(x, rng); +} + +template +__host__ __device__ T cast_from_f8(f8_t x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + return run_cast_from_f8(x); +} + +} // namespace ck::utils diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 08e730782f386cf5788c64bc04a008dc6cb37b28..91797d24092e3e32ad4a6bd40958952b124d9978 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 6f125ca4c944777d02ae1083334bec4602ad68c4..99c65f4eb85b67231557b1916fae10c6568e676d 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional3.hpp b/include/ck/utility/functional3.hpp index 06b67ef7e3fdec25a1ab2e18e23c904342233ebe..97605a7adeb8ae64e4e5a32debe9386295923b6c 100644 --- a/include/ck/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index 6eeaf15c9b7ac283f7fcac96c195d28f179b6065..b5f3df8d7c517dfaf01320e41721da174883c2d9 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 6a1ca966521710ecdb4ceaad9a83457031f4c508..98f40a4363aa2ddf2908fa497151f756b77d6f94 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 44ff438155d2a9c9ba9b0925d4e9fe95c1fa6bce..77564c6130baf45bfd331e69fc437fb7c7c96d18 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp index ac33cbf9a508f5e6402d17d6285ae34a8196f6eb..f70a182fd4e5c6cc623acb2c65b5551b2a1acd14 100644 --- a/include/ck/utility/ignore.hpp +++ b/include/ck/utility/ignore.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b65640bfffed31e319549b7152f0c2a93bfaded9..b13bccb5ab3fd4e1a01470aee97bca78d93e286c 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" +#include "type_convert.hpp" namespace ck { diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 9aab4e24214a884bdfb391e0f9040bb3af2630ec..376070eb3d8ac326603b71e52e76949c168f4219 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index 8198154422e5d5228b79c6539bddcd5070d1e25c..2cafc3e6f2fafd247e378446a15a7e16c019c914 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index a5e8e9216519074d6542d379512c5b8d13ec21c9..f19030d4e9ca38971f3ffbfd5e4ff6b951eb16b3 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 12203bd7f3146955b2568dcfb40c26403ff4b70a..326b0e61eff611512bd668355358de421f9f376a 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -168,6 +168,10 @@ __device__ double exp(double x) return exp(x); } +static inline __host__ float exp(float x) { return std::expf(x); } + +static inline __host__ double exp(double x) { return std::exp(x); } + // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 4febace0b843ecbd5b785ba96bc9f36c3a58a73b..1cac2cc0c7172b4b1915845d6c7f3fa853653683 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); }; +static inline __host__ half_t tanh(half_t x) +{ + return static_cast(std::tanh(static_cast(x))); +}; + +static inline __host__ float tanh(float x) { return std::tanh(x); }; + +static inline __host__ double tanh(double x) { return std::tanh(x); }; + // math functions for the HIP kernel, some are implemented by calling hip builtin functions static inline __device__ float abs(float x) { return ::abs(x); }; @@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; +static inline __device__ half_t tanh(half_t x) +{ + return static_cast(::tanhf(static_cast(x))); +}; + +static inline __device__ float tanh(float x) { return ::tanhf(x); }; + +static inline __device__ double tanh(double x) { return ::tanh(x); }; + } // namespace math } // namespace ck diff --git a/include/ck/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp index 1d544c0906cae3c61f4d7ce27e74e7636c3919b5..9f7ba8bff63b3f6e625b98f66accd875ee118594 100644 --- a/include/ck/utility/multi_index.hpp +++ b/include/ck/utility/multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index f3ca6b61dc6ac08330a6cf1633148bb9bad8cd81..d29afd31a7ef03ef3fdbf11f8e9a88132e07fe82 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_NUMBER_HPP #define CK_NUMBER_HPP diff --git a/include/ck/utility/print.hpp b/include/ck/utility/print.hpp deleted file mode 100644 index eed1ca42c7346e32f4d6924452d568bf431d6c24..0000000000000000000000000000000000000000 --- a/include/ck/utility/print.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_PRINT_HPP -#define CK_PRINT_HPP - -#include "array.hpp" -#include "statically_indexed_array.hpp" -#include "container_helper.hpp" -#include "sequence.hpp" - -namespace ck { - -template -__host__ __device__ void print_array(const char* s, T a) -{ - constexpr index_t nsize = a.Size(); - - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); - printf("}\n"); -} - -} // namespace ck -#endif diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b7edf26507c62d5365ecab3fe1660c39a2cd672f --- /dev/null +++ b/include/ck/utility/random_gen.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +// Pseudo random number generator +// version for fp32 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint32_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits ^= x >> 16; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// version for fp16 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint16_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// return 0 if data is not fp16 or fp32 +template {} || std::is_same{}), bool> = false> +__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) +{ + std::ignore = id; + std::ignore = val; + std::ignore = seed; + + return 0; +} + +} // namespace ck diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp index aceef7b296da7466c484aad8033fe1f949662fca..3777d297c8c924cbf4a58a4fa79751be2cfa2f44 100644 --- a/include/ck/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp index 67856331059cef628a62e4b2223f262b839bbf02..23b7149f8eb022a00d1aed81b4b69c315ca37ba0 100644 --- a/include/ck/utility/reduction_enums.hpp +++ b/include/ck/utility/reduction_enums.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp index 724e5599d6c878ba08bf14429c40c0958ab9f2a2..b9765ff0d2cb57bb2a1a0c4c7684a3cd2e0fa008 100644 --- a/include/ck/utility/reduction_functions_accumulate.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index b4e770a64efdda0b42e7996e2858d0f162231fee..36c25203ea12b7a30b1991762c229da7cfcf417a 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 97b597221c2850d4b28b644266978fc56b295913..d6bfb2eba1cd6a6de2e334261379504d8448b3c7 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index db25c27e70c3653f94524367b8a6bce79480113e..8c493a28221eae4e095f051264f9323429984531 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/span.hpp b/include/ck/utility/span.hpp index 1e501214547cfbcec25921d6526e77563504835a..5e7567a847f3d18e424437005d876a409cd0d99d 100644 --- a/include/ck/utility/span.hpp +++ b/include/ck/utility/span.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index dd25c96203288cfc9aa0959e9a7b1d83ef4abfc2..835f5657307a8e171b116c0bec088bb84b1b7b38 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index 3438776f413cf010664cc8aa4a18c09bf161fae7..a2d70045a4b9b221651d958472aa6a2e721ccccc 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_HPP #define CK_STATICALLY_INDEXED_ARRAY_HPP diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index 21b2941b21401ac561932a45ec304536365f2927..4a8b96ae8a6b81400f3567957f07010c26c0bf8d 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 0e247ed0f8df15b890cdd44ea9f1682d5823a89f..775e7ac3a392f5edd16054e63314bc80faca5d6e 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp index d469dec899a556c5b8062efa775b7dc694500fe8..1cd6b2f3ce26a9a89e98600f7377aac0d93af344 100644 --- a/include/ck/utility/thread_group.hpp +++ b/include/ck/utility/thread_group.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index 2b0075d6005e3c5cf7d6772fda8132a6404878ed..6faf5c133ba2df0ed938aab915a272e3a7a590f3 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index d8664be550b79bc4de2edd224bc63350a0a1bde7..b616b3123f3b9ce3c92763cd36ecd39fc7d3e553 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 6f5b142a5e7a765807bd9d7f556f7b8afc512d37..e39ae1c23d421e380b81b4bc7b487025cb6937be 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index 90b9df2950b7d979f8cef386d155c72f6d7a39a5..9609afba43a78c05332d459b3a7cf756931a6e60 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed83964936c4fc00dc88a2cc17ff117bb47a8e41 --- /dev/null +++ b/include/ck/utility/type_convert.hpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/f8_utils.hpp" +#include "ck/utility/random_gen.hpp" + +namespace ck { + +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// convert bfp16 to fp32 +template <> +inline __host__ __device__ constexpr float type_convert(bhalf_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + + return u.fp32; +} + +// convert fp32 to bfp16 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + return uint16_t(u.int32 >> 16); +} + +// convert bfp16 to fp16 via fp32 +template <> +inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert fp16 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int32 via fp32 +template <> +inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int32 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int8 via fp32 +template <> +inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int8 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert fp32 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp32 +template <> +inline __host__ __device__ float type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// convert fp16 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp16 +template <> +inline __host__ __device__ half_t type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// Declare a template function for bf16 conversion using RTN +template +__host__ __device__ constexpr Y bf16_convert_rtn(X x); + +// Convert fp32 to bf16 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + bool flag0 = ~u.int32 & 0x7f800000; + + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + bool flag1 = !flag0 && (u.int32 & 0xffff); + + u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even + u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN + + return uint16_t(u.int32 >> 16); +} + +// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) +{ + float x_fp32 = static_cast(x); + + return bf16_convert_rtn(x_fp32); +} + +// Declare a template function for fp8 conversion using SR +template +__host__ __device__ constexpr Y f8_convert_sr(X x); + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +// convert fp16 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 46a1fa559a1a3d164acc99e536f7e26a83b97ed4..a1b1e0d91b472e5531097b13d57d6efe5b9591a7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp index 0b621e88a0c6e890bf1efb026615fcc63fadf6c9..a2eabdf5c1c92375503fc8cc961a2cde8410204e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp index dd0db316804fd191ee73d9952f6df6779420910d..20c1fcd7367838022d7a4e2776f436ce796460fe 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp index 463c655ac1d1414f80f9d8b7fcfc696be6326018..7d652fe4c41d044747c2c40a3d6e7383d4f1ff2e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp index b0149d88fdb062806d3b34b400f9b6d3389419f9..24f754e5987adfa82f4b0cc0e41dc85431f94e12 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..92a8c82a6e81c85ebe79aa8941f51d0a53e4734d --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& c_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + c_ms_ns_{c_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& c_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(ck::index_t k0 = 0; k0 < K0; ++k0) + { + for(ck::index_t k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + arg.c_ms_ns_(m0, m1, n0, n1) = v_acc; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.c_ms_ns_.mDesc.GetLengths()[0], + arg.c_ms_ns_.mDesc.GetLengths()[1], + arg.c_ms_ns_.mDesc.GetLengths()[2], + arg.c_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& c_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op) + { + return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 225f7b7e36f3a5a0d70058ed7089fc693a9e0493..449734f43478f0b36215d9d876c4a7537675a13f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index 7d62158f00c0aa28fdfe90d4a9df537367b22717..ec5df238ab2c2a292d8d44951f67e2d8d2383d1a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index b8d47d218b95e3ff2b38c2fc965b6bdc6a2afc68..8f4182a2318017dc7239dcc32e34af5287e97b4d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp index be22003fd90db20eaea80045bd34f908ed8ebb61..71c84a1f5cd43811105a4630e31ade8ae0b3a075 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp index f949f27fde973cc44eec3162c7625980a9818116..0b90b4b50eec00cbd4d807fe2f43b9280cd2bdf2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 6728bb1f471b4aad3f7ed0a2878fc7a81e1b569d..9b797be92549703fff3397fe9f099ad3e76c8389 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -66,8 +67,26 @@ struct ReferenceGemm : public device::BaseOperator ADataType v_a; BDataType v_b; - arg.a_element_op_(v_a, arg.a_m_k_(m, k)); - arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp deleted file mode 100644 index c77d22f4cd1c34d3fa81a973f115fc37b7d03cf0..0000000000000000000000000000000000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/library/utility/host_tensor.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceGemmBias2D : public device::BaseOperator -{ - // Argument - struct Argument : public device::BaseArgument - { - Argument(const Tensor& a_m_k, - const Tensor& b_k_n, - const Tensor& c0_m_n, - Tensor& c_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : a_m_k_{a_m_k}, - b_k_n_{b_k_n}, - c0_m_n_{c0_m_n}, - c_m_n_{c_m_n}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - } - - const Tensor& a_m_k_; - const Tensor& b_k_n_; - const Tensor& c0_m_n_; - Tensor& c_m_n_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public device::BaseInvoker - { - using Argument = ReferenceGemmBias2D::Argument; - - float Run(const Argument& arg) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - - AccDataType a = 0; - AccDataType b = 0; - AccDataType acc = 0; - - for(int k = 0; k < K; ++k) - { - arg.a_element_op_(a, ck::type_convert(arg.a_m_k_(m, k))); - arg.b_element_op_(b, ck::type_convert(arg.b_k_n_(k, n))); - acc += a * b; - } - - CDataType cast_acc = static_cast(acc); - arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n)); - }; - - make_ParallelTensorFunctor( - f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - - static auto MakeArgument(const Tensor& a_m_k, - const Tensor& b_k_n, - const Tensor& c0_m_n, - Tensor& c_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceGemmBias2D" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp deleted file mode 100644 index 7dfc3c1ed4b868756277f43d341cca5a1796662e..0000000000000000000000000000000000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp +++ /dev/null @@ -1,140 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -#include "ck/library/utility/host_tensor.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceGemmBiasActivation : public device::BaseOperator -{ - // Argument - struct Argument : public device::BaseArgument - { - Argument(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const Tensor& c0_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : a_m_k_{a_m_k}, - b_k_n_{b_k_n}, - c_m_n_{c_m_n}, - c0_n_{c0_n}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - } - - const Tensor& a_m_k_; - const Tensor& b_k_n_; - Tensor& c_m_n_; - const Tensor& c0_n_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public device::BaseInvoker - { - using Argument = ReferenceGemmBiasActivation::Argument; - - float Run(const Argument& arg) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - - float v_acc = 0; - - for(int k = 0; k < K; ++k) - { - float v_a; - float v_b; - - arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); - arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); - - v_acc += v_a * v_b; - } - - float v_c; - - arg.c_element_op_(v_c, v_acc, static_cast(arg.c0_n_(n))); - - arg.c_m_n_(m, n) = v_c; - }; - - make_ParallelTensorFunctor( - f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - - static auto MakeArgument(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const Tensor& c0_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceGemmBiasActivation" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp deleted file mode 100644 index 99102a40d4e9a4c6360b3e6ecf9c8a345f5b38a9..0000000000000000000000000000000000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp +++ /dev/null @@ -1,148 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -#include "ck/library/utility/host_tensor.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceGemmBiasActivationAdd : public device::BaseOperator -{ - // Argument - struct Argument : public device::BaseArgument - { - Argument(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const Tensor& c0_n, - const Tensor& c1_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : a_m_k_{a_m_k}, - b_k_n_{b_k_n}, - c_m_n_{c_m_n}, - c0_n_{c0_n}, - c1_m_n_{c1_m_n}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - } - - const Tensor& a_m_k_; - const Tensor& b_k_n_; - Tensor& c_m_n_; - const Tensor& c0_n_; - const Tensor& c1_m_n_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public device::BaseInvoker - { - using Argument = ReferenceGemmBiasActivationAdd::Argument; - - float Run(const Argument& arg) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - - float v_acc = 0; - - for(int k = 0; k < K; ++k) - { - float v_a; - float v_b; - - arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); - arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); - - v_acc += v_a * v_b; - } - - float v_c; - - arg.c_element_op_(v_c, - v_acc, - static_cast(arg.c0_n_(n)), - static_cast(arg.c1_m_n_(m, n))); - - arg.c_m_n_(m, n) = v_c; - }; - - make_ParallelTensorFunctor( - f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - - static auto MakeArgument(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const Tensor& c0_n, - const Tensor& c1_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{ - a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceGemmBiasActivationAdd" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp index 28132aa1ebd78178774caba73b91ba3c947b66be..ce2a83da6163fb6b19e30121a48734096857c508 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp index fedd4dce62cdcd31c2c582ab034291e9e098bc3c..6a48528c543b3ad1bcf94f92effcacbce40a8151 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp index 2bac5bc5c8f5cd6f1629e8cd108253fc6e9fee01..9994a2f9f7c862a1bf870103f999634e17682ecf 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f1fc6165cd00daf9f2c8ae731dfc62a71d2ccd6 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { +using namespace std; + +template +struct ReferenceMaxPoolBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& dout, + const Tensor& indices, + Tensor& din, + ElementwiseOperation elementwise_op) + : dout_(dout), indices_(indices), din_(din), elementwise_op_(elementwise_op) + { + } + + const Tensor& dout_; + const Tensor& indices_; + Tensor& din_; + ElementwiseOperation elementwise_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int din_length = arg.din_.GetElementSpaceSize(); + int dout_length = arg.dout_.GetElementSpaceSize(); + std::vector buf(din_length, 0); + + for(int i = 0; i < dout_length; ++i) + { + int index = arg.indices_.mData[i]; + if(index >= 0 && index < din_length) + buf[index] += ck::type_convert(arg.dout_.mData[i]); + } + + for(int i = 0; i < din_length; ++i) + arg.din_.mData[i] = ck::type_convert(buf[i]); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& dout, + const Tensor& indices, + Tensor& din, + ElementwiseOperation elementwise_op) + { + return Argument{dout, indices, din, elementwise_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMaxPoolBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..696fb5eaf1413755fd3990f5db4e9b34f72fe654 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferencePoolingFwd : public device::BaseOperator +{ + using ReduceOperation = typename ck::reduce_binary_operator::opType; + + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::vector& window_spatial_lengths, + const std::vector& window_strides, + const std::vector& in_left_pads, + const std::vector& /*in_right_pads*/) + : in_(in), + out_(out), + out_indices_(out_indices), + window_spatial_lengths_(window_spatial_lengths), + window_strides_(window_strides), + in_left_pads_(in_left_pads), + reduceLength_(1) + { + static_for<0, WindowRank, 1>{}( + [&](auto I) { reduceLength_ *= window_spatial_lengths[I]; }); + } + + const Tensor& in_; + Tensor& out_; + Tensor& out_indices_; + const std::vector& window_spatial_lengths_; + const std::vector& window_strides_; + const std::vector& in_left_pads_; + int reduceLength_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float RunPooling3dFwd(const Argument& arg) + { + + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator( + arg.reduceLength_); + + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); + + if constexpr(!OutputIndex) + { + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + + for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) + { + ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; + for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) + { + ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; + for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) + { + ck::index_t wi = + wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; + if(di >= 0 && + di < static_cast(arg.in_.mDesc.GetLengths()[2]) && + hi >= 0 && + hi < static_cast(arg.in_.mDesc.GetLengths()[3]) && + wi >= 0 && + wi < static_cast(arg.in_.mDesc.GetLengths()[4])) + { + ComputeDataType currVal = ck::type_convert( + arg.in_(n, c, di, hi, wi)); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + } + } + } + } + acc_elementwise_op(accuVal, accuVal); + + arg.out_(n, c, do_, ho, wo) = ck::type_convert(accuVal); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.out_.mDesc.GetLengths()[0], + arg.out_.mDesc.GetLengths()[1], + arg.out_.mDesc.GetLengths()[2], + arg.out_.mDesc.GetLengths()[3], + arg.out_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + } + else + { + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + + auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) + { + ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; + for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) + { + ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; + for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) + { + ck::index_t wi = + wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; + if(di >= 0 && + di < static_cast(arg.in_.mDesc.GetLengths()[2]) && + hi >= 0 && + hi < static_cast(arg.in_.mDesc.GetLengths()[3]) && + wi >= 0 && + wi < static_cast(arg.in_.mDesc.GetLengths()[4])) + { + ComputeDataType currVal = ck::type_convert( + arg.in_(n, c, di, hi, wi)); + IndexDataType currIndex = + arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + } + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + arg.out_(n, c, do_, ho, wo) = ck::type_convert(accuVal); + arg.out_indices_(n, c, do_, ho, wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.out_.mDesc.GetLengths()[0], + arg.out_.mDesc.GetLengths()[1], + arg.out_.mDesc.GetLengths()[2], + arg.out_.mDesc.GetLengths()[3], + arg.out_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + }; + + return 0; + } + + float RunPooling2dFwd(const Argument& arg) + { + + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator( + arg.reduceLength_); + + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); + + if constexpr(!OutputIndex) + { + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + + for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) + { + ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; + for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) + { + ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; + if(hi >= 0 && + hi < static_cast(arg.in_.mDesc.GetLengths()[2]) && + wi >= 0 && + wi < static_cast(arg.in_.mDesc.GetLengths()[3])) + { + ComputeDataType currVal = + ck::type_convert(arg.in_(n, c, hi, wi)); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + arg.out_(n, c, ho, wo) = ck::type_convert(accuVal); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_.mDesc.GetLengths()[0], + arg.out_.mDesc.GetLengths()[1], + arg.out_.mDesc.GetLengths()[2], + arg.out_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + } + else + { + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) + { + ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; + for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) + { + ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; + if(hi >= 0 && + hi < static_cast(arg.in_.mDesc.GetLengths()[2]) && + wi >= 0 && + wi < static_cast(arg.in_.mDesc.GetLengths()[3])) + { + ComputeDataType currVal = + ck::type_convert(arg.in_(n, c, hi, wi)); + + IndexDataType currIndex = + arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + arg.out_(n, c, ho, wo) = ck::type_convert(accuVal); + arg.out_indices_(n, c, ho, wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_.mDesc.GetLengths()[0], + arg.out_.mDesc.GetLengths()[1], + arg.out_.mDesc.GetLengths()[2], + arg.out_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + }; + + return 0; + } + + float Run(const Argument& arg) + { + // TODO - support generic pooling + if constexpr(InOutRank == 5 && WindowRank == 3) + return RunPooling3dFwd(arg); + else if constexpr(InOutRank == 4 && WindowRank == 2) + return RunPooling2dFwd(arg); + else + throw std::runtime_error("Only support pooling3d or pooling2d so far"); + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::vector& window_spatial_lengths, + const std::vector& window_strides, + const std::vector& in_left_pads, + const std::vector& in_right_pads) + { + return Argument{in, + out, + out_indices, + window_spatial_lengths, + window_strides, + in_left_pads, + in_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferencePoolingFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp index c04baca57491ddce5f4dd09469ddb116840bcecd..944f34007ef0e030e2017688477c27602df15391 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp index a4fd46c932c08f3d51df8db212e167d66b755c27..9916a03b9c8f2eabe20ef0c27d78def6f4b9d6a5 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp index b6a9b0fb5ee4a3135750dad0d0818a0f6e439f3a..f949260ca4552b83ab6e8edafee9190d4f95cc64 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp index df4fca6562755012200ff888677bcac1c90129a2..0b7887efb6e05f34b83a5bf684b5fedd3d78c385 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef NAIVE_CONV_FWD_HPP #define NAIVE_CONV_FWD_HPP diff --git a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp index 20df1b3616a016029f4083bd821b4e286b7b727c..f57fed9c07c7843af0f77db7991ce077b3021245 100644 --- a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 104b21a3ec4ab0f43fdeb02b5b3fa5a4a0b61e1e..851d9e497ffe6a0f7b506b43be841ccb8554bd6e 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>; using F16_Tuple = ck::Tuple; using F16_F16_Tuple = ck::Tuple; +using F64_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; @@ -85,6 +86,7 @@ using GK_GK_Tuple = ck::Tuple; // pointwise functor using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; +using TanH = ck::tensor_operation::element_wise::TanH; using Scale = ck::tensor_operation::element_wise::Scale; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; @@ -94,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using Gelu = ck::tensor_operation::element_wise::Gelu; +using Swish = ck::tensor_operation::element_wise::Swish; template using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; @@ -102,6 +105,10 @@ template using Add_Activation_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +template +using Add_Mul_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + template using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; @@ -109,6 +116,10 @@ template using Add_Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +template +using Add_Mul2_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + template struct DeviceOperationInstanceFactory; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp index 0655fd92e44c5c6400c5a66e6ee9e21f6e50da71..c3c8c0e5a76329071970a2ee9e81f297fbf11292 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp index 495c5f884fd633898867a3e0a5fc121326527c8e..73f6004257135b75ac0bb1cd79f92e0d9d78c3ef 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp index 59d50e1bd23448029054d64f9d7d557870ec9310..70bff278997501254d47595f28cb0d2d01c679e2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp index 0aa7a5aa3ddf860f15ce3e1dfd5579cd46554981..33653a30800ebb7a22d333d0ccc1bd8ed4506df4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index a6dcfa30d3efd2d78562c1829ffb72b4eaabca9b..28ccf61a3a1d46f3dfb6c7158192416622f7907a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ae12f4c7ac36aa0bff962f4657ab708a0ac4085d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceBatchedGemmMultiD; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp index 8a0b1b1fa7de54954a5bc375241bb7a0686b2e49..0a30c210dd818b7812a0ae52690afb5a7a991ea2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp index 89df1a7a0d44394d190973ad787b176da8002851..2ff64675c904f0f6e30c436e6c44b57211c00f89 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp index c84ffcff8cbb8d4820f88dbfa3e7de6a04961bb8..0e1f6f04e8c58812651d3d549601952050fb73dd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp index 8e40d60c17b8dc8d743e29cc1472a1e9eed9bdc1..8fd1c7665d143e9a0fb07f47f07ead18aa33aa88 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp index 342ade69cdf499c8a8a960038c288d4921a813b6..f6f4df7e2ec0e34b6ac7a4b356649289f6a30c91 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index a0cea7e390aab9402db40a439ffb43ccfb335bde..2ed8255f6d66b72cba17492351824a0007afe6dc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -1,12 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include #include #include - #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" @@ -19,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { +// float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); +// double +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances); + // Contraction + Bilinear template && is_same_v && + is_same_v && is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + op_ptrs); + } + } + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index e921ecd47aa5b08d2f6f63ee2bad30c40608b05c..5d9567731dae2ea2d491cbeade861486f36a1743 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -1,12 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include #include #include - #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" @@ -19,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { +// float void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances); +// double +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances); + // Contraction + Scale template && is_same_v && + is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + op_ptrs); + } + } + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp index ec5d18fc214c34c1e972aa5fc50726cd41f8180c..1efe073669846e40274897ec33cfbf932cb31715 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp index 62f28c9b11d287f355d5feeeb24bf7863ac96851..dd8c1987cfe65bfba4e792f269004f52cdd912ad 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp index 381a015eb08aeede66fb903be5c517e8e4ab9664..b03693b00aabceef8614a56e91c0ce9b89207def 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp @@ -1,16 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" namespace ck { namespace tensor_operation { @@ -30,20 +28,34 @@ template -auto get_device_normalize_from_mean_meansquare_instances() +struct DeviceOperationInstanceFactory, + ck::Tuple, + Normalize, + 2>> { - std::vector op_ptrs; + using DeviceOp = DeviceElementwise< + ck::Tuple, + ck::Tuple, + Normalize, + 2>; - if constexpr(is_same::value && is_same::value && - is_same::value && is_same::value && - is_same::value && is_same::value) + static auto GetInstances() { - ck::tensor_operation::device::instance:: - add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); - } - - return op_ptrs; -} + std::vector> op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value && + is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); + } + + return op_ptrs; + }; +}; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp index 682f54675982e9338823a814c46075242534de1d..b15139510935586a930cd5a7c5a091299dfddecc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp index c87ae159bee2635c903490b9e09096be8c242c8c..a1c006cf62a063a742a8dd29365a0aefe359282d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index e230507e7e30a605bac4ed5d27cea454739ba24a..adac7d0dcf4a7bee2a0315add52d4c85251208dc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -1,12 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include #include #include - #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" @@ -24,21 +22,41 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( std::vector>>& @@ -65,21 +83,41 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector>>& @@ -297,6 +335,7 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && @@ -304,6 +343,7 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } @@ -312,6 +352,7 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && @@ -319,6 +360,7 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } @@ -354,24 +396,28 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp index 90b6e11b9b84bdaf79edbcb057693f2bd4aae441..99b2ad13152f0baa82872add5efdeb9b6a1988a8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -1,12 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include #include #include - #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp index 554437f4903bcae96f68f2a53097a6eaf88367f3..fd3550c2f01b733602e6ae84dcca34f74255c9fe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp index c07ca3134bb889e5d15fc5df521979d5ea2efe46..481915d00b7416a444eabdf92289d275a9f60ef9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp index 7beae83cdc1799814a2e7f4a1a4e7fa9e2d59a08..de21f325a6711d1fe746905dfb7686ddb4ff7e7c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index ef70504f29b717ed70a494926d256f8b98a35e47..f80efd5511a379a300078bd8d2cb63bcc0ce0154 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -1,12 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include #include #include - #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp index fbc5df98a4ec116976d491835e8a8411b609c992..09b1c2190facdd84e09e5e1d66f32b4212ef0ec8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index 8986a793444c8f73679d892f4d55adb02bc61834..534875151105ee04a147584e61623320ee18ce57 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 81b2b4fcf37cb1f0d05415eb7d56587f0a29f380..0b20e19a2a980975d057798549055d2c4a8d0ef0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,6 +30,76 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); + template && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + } } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index ef6920e52a2096a3f61e40435c6fa3ce64a17323..377ce083c74b2419a34b1c966c951d659bdaeedd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index a8df7f0d5bfed1948ee5b0d8e79fe17cc19f0794..627a5ae2aa9d49d1e9936d094f4240eb9907c304 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); -// grouped conv2d forward, NHWGC/GKYXC/NHWGK + void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( std::vector>>& instances); -// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); - } } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) @@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - // no instance + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory && is_same_v) { - // no instance - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - // no instance + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } } else if constexpr(NumDimSpatial == 3 && is_same_v && @@ -443,28 +433,28 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index fa7462dbd955c1f825846ccc5c1dba0baae53916..b482e97ee0386ba8d7940481f93512e797aa2f4f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" @@ -68,6 +68,58 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& instances); + template ) { add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances( + op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( + op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp index 55c67b7623b85c03381833431338f6be083fbd7b..778f625d84cbba947bd4d81ecc59bb94ebfe195b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include - +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp new file mode 100644 index 0000000000000000000000000000000000000000..23917752997b42e7c67e667247e38a1762a50a8d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_normalization_rank_5_3_swish_f16_instances( + std::vector>>&); + +// FP32 +void add_device_normalization_rank_5_3_swish_f32_instances( + std::vector>>&); + +// [x, gamma, beta, y] = [f16, f32, f32, f16] +void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( + std::vector>>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalization> +{ + using DeviceOp = DeviceNormalization; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ccb5cb5a9eb9e74932789cfa25f299707746bce6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto InOutRank = 4; +static constexpr auto WindowRank = 2; + +static constexpr auto MaxOp = ck::ReduceTensorOp::MAX; +static constexpr auto AvgOp = ck::ReduceTensorOp::AVG; + +// FP16 +void add_device_pool2d_fwd_nhwc_f16_instances( + std::vector< + std::unique_ptr>>&); + +void add_device_pool2d_fwd_nhwc_f16_instances( + std::vector< + std::unique_ptr>>&); + +// FP16 - return index +void add_device_pool2d_fwd_nhwc_index_f16_instances( + std::vector< + std::unique_ptr>>&); + +// FP32 +void add_device_pool2d_fwd_nhwc_f32_instances( + std::vector< + std::unique_ptr>>&); + +void add_device_pool2d_fwd_nhwc_f32_instances( + std::vector< + std::unique_ptr>>&); + +// FP32 - return index +void add_device_pool2d_fwd_nhwc_index_f32_instances( + std::vector< + std::unique_ptr>>&); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DevicePoolFwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs); + } + else + { + add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs); + } + else + { + add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a006b00a5b04c2689d2a262aecee5623f02d195 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto InOutRank = 5; +static constexpr auto WindowRank = 3; + +static constexpr auto MaxOp = ck::ReduceTensorOp::MAX; +static constexpr auto AvgOp = ck::ReduceTensorOp::AVG; + +// FP16 +void add_device_pool3d_fwd_ndhwc_f16_instances( + std::vector< + std::unique_ptr>>&); + +void add_device_pool3d_fwd_ndhwc_f16_instances( + std::vector< + std::unique_ptr>>&); + +// FP16 - return index +void add_device_pool3d_fwd_ndhwc_index_f16_instances( + std::vector< + std::unique_ptr>>&); + +// FP32 +void add_device_pool3d_fwd_ndhwc_f32_instances( + std::vector< + std::unique_ptr>>&); + +void add_device_pool3d_fwd_ndhwc_f32_instances( + std::vector< + std::unique_ptr>>&); + +// FP32 - return index +void add_device_pool3d_fwd_ndhwc_index_f32_instances( + std::vector< + std::unique_ptr>>&); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DevicePoolFwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs); + } + else + { + add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs); + } + else + { + add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp index 2fd7ce22f7010f9e7b738395b5f670ac14aa9924..2ed4b0d5f880d5e8f0e8ba26e7e2ba541201c63f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp index 57c971e52e838135665657e22156acea0c29ddd5..8a96f6707e35344762003a958bc7afb9b937c86f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,14 +17,14 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); +void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); +void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +// piecewise activation function template > op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + { + add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); + } + } + } + + return op_ptrs; + } +}; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp index 9f8ac9b7b1127731bf994763d5fb388227f88d69..e17bea5fdf05ba19bb9e36de450865b67c406ff7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,14 +17,14 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); +void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); +void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +// piecewise activation function template > op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + { + add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); + } + } + } + + return op_ptrs; + } +}; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp index 089343fe6f7ffcdd09d79b37856cf78300b39421..9236b5c79e3079b0ee1f7c68091f15866d692a3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,13 +17,13 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_perchannel_quantization_int8_instances( std::vector> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp index e570027eb5b0e092a37e3749be6a067ab0cac435..7c1eb4e429c727099950cb8d722051431c4064bb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,13 +17,13 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_perlayer_quantization_int8_instances( std::vector> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && - is_same_v && is_same_v) + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp index 550a7b034509af43412c0faf5ac2fb52fb8b2eb3..9930b1a6fa04a5225432a6f8b0499293189b66be 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp index 2cdbfbb0c2ee6da3ba452fe43998537dd64f4bfc..c9c1475f182578f83288a3b0c1e9db942af494a7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp index 4e3fa81f75f41c842688f51310c6c00e50d03082..4dd5569cee34aef47aca01c967ba80aca4b9ec7c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp index 7ca8bc258ae3d163586ca60e85f467d974e3f901..d52310a3f3b1b8aecfcb0bfc97f9641943cd9d55 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp index 37398146b873403cf68999cf4c6bae3191430a06..025500764ea6730b784cbafec1083cbc2cc0af51 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp index 5eacd358c87ae5d0a82e9c06ca82c74492c744f6..2314e94980fb927b034249ecbafa7877056e6cae 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp index 94ae02bf3dccaa6df9704cc3f114e2a41abda3c8..5c2bff16cccf6377d347992ab1224bcf27ecff0a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp index e41e8de6a50a5d0b37177149529edd9d289e6ff4..a1279eeccbf0e38f3862b54aad87204ec44bf2da 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp index 99762aa64b320e3f6d80e6a380f8b7bff4221e41..aff6a2542bf7a195b090c3a3ee8d1ff4c42f2776 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp index 1fc557a95dbd9b12081583f8ca42c35fc415a20c..be8da2243fae00df87d2d46d34f1d709dfa32e46 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp index ca3ba4eb0beb4be251a358de6bbd8f22b1341859..652984ae22171b9e0b1941d8379df994cbc30a9c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp index 28a85782d13971cf7b3b49210e291ea569ab27e9..be60d1b32db50457684a8059f25cc29dc34f6467 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp index ba74400793fdded7b24f82beb49810f5ac0d54b1..27e3aa53e38911afaf2d47c6900a1fb09ab4ec89 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp index f5c813de78156975f2bfe4be4312bf05041620dd..f7f4870a97e98b473773ec70e167bc4818ab469a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp index e25b6e84938e8ebeac212e95a78211b6801db7f0..790b5a92ad7532a26d4663b37189e6d9346fea27 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp index a264d11262508751a4121f257b0837b1af7ab4a1..ec3bc852e8e97e517aa6f18be36b3a40646fa542 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp index 8b1d8c95baba5f56139ac26b0539d977f3a5e3bc..8c0c065670273a5a7ade3ef4f234236589f8ae3e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp index 49a60d88c3fbe4bd5eb6fcadcd8120c722ab3383..8631495df12d6cf9a9154b2a976c80a719352d48 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp index 04a7c2d238f9cb97add1b6b615d7dda83f84c933..59849c2d459b6fd8fbd81adad00fac4493e19686 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp index d0feefb50d02d0e40e022207fb5200ba28dc8b43..33cbea85efefe3d7509888dd3c15d07dca28dc5c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp index 35f35f202c3055a49652c9aa8542fac504d024da..386182957514384e6c68fbacc961ade931817392 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp index 63eb7221b5025f978c40383486a66bf45ad2caf4..f7c05b20a7ce8ff8acf13de3b33b2f52033ce2cb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp index 1bca3c1f432b4b403c6e71a8831e9c03acc83368..2dfecc6b07b2819efd80135a0269d820bb2abaa4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp index 1791a186f5a3fa41481f2674f1e47a0f07fdb5ab..a687938965e34a45c05d356e94983494fb8b6570 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp index 3f56c057ef636a67a66f7757c4f9dae4bbe03a43..cf9c268a6476be9b62e76742a7d35293dffcbf0d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp index a3b8bcf9a0a12d1bd10dc9304c045af3ea66d481..aeb578f948cf2e3d809e44d2975baceb7d20112f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp index 18e0e084d768f2f1a2a11b5f384128910021d5c6..73480262c66e99af5afd1d360189585f7fff7410 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp index 4a106463a31eb2a259b02161f29745bac9490d99..74293553e1a168b0a75592bc4f77d9325cfafe60 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp index 23e1c49fe927cee5384b0a37ad3dc26206a5c694..8a91f76b34faf0b52378ebdb3462d85c20631a12 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp index 62e2d24f0248319527e8dd0d905098252aa7f1f4..0ff2c30f346047d092d0ecd9b2e952fd99fe78d6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp index 18a54d86862de4837c843301e53d6ff90d3ba419..932299008eba691e5b2f704ca47a00b9159ab5ae 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp index 9f408906a7c01d4e1d0eb7eb9eb12c7896b9cafe..c902e80a2c731f76f4608d4501ad711a1f1a68be 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp index c40052562f8deba18c7ef5170e2a1f3236cf217d..e7a1369277cddb95cd9da2d06a2b16fdad643423 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp index 532bfb417e5c9ae706b8d227bccaf688f50fdcf9..2d1ab69fc75f74ae9128824f842b1e7f89501db3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp index 8c08e5ef2f0365e36125d141f95474573dce5bf6..9f782c11b6dd306c14e1027ff083baa17821b5f5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp index 0d08377a226ab06f974e74d63fb9761e579918d2..bb45295ed82e8ed5d6602e3642002f6783bb31fd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp index 4cdd45e85b3781f08ea257b73d5780fbd1aa406a..5814b84bb7b295e261b86817b4698ab01284ad41 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp index a36cafb27a02bce42b08bf70c512b2cbbef3e2b9..f1a609a6f483d979c394ce5b3bd74b19eb14a803 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp index 13b0780497fafa06281a5dd4ae196c6695844ac2..489bef1f561102eb05e9bfc58da3cf45e6bcf54a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp index 75e1f102421f5cf916aeeea39a7bf2173341aec8..2507d82308da420c51d3c4b1e2a6071aedf186a0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp index 00ec17fadad0527bcf0713350c189f5aef228728..df4b7b7b493516a1e01fdc2be435f255bf5cd024 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp index 7b762bc93216ba9cb525849ddd5c618fd097bfb9..2748b9b9d0434949c50c1c0e4f2e76f2b25bb8af 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp index 2a2b284b226a37d07c38db64b3ab2686bf2c30ce..26230d5456c533175b62e54ead132486c80b494b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp index 444d8ddc86b65cce2961874ee7c64b69a9573d2f..d5e1499aefc0ada2270f6efdda5776613f9a9c5f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp index f3c07017616adaa8bbd529569d2d951b1a9f2a4e..cbc8befb3d39c36a199159b9a1d3759e17f2ce4a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp index c57edd0846644d1d43c2d6578fe57b3b7795f2ff..574cc3dd3e015b4194cb0f584f61dbc8afb6e4e5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp index f77c50a8e8c8a7be78c6a864175898fda22ac3ab..04bab9d283fa83f57eeae64831966fe115dde9c2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -90,6 +90,7 @@ void add_device_reduce_instance_threadwise( AccElementwiseOp, PropagateNan, OutputIndex, + false, false, // HaveIndexInputIfOutputIndex cfg1::BlockSize_, cfg2::MThreadSliceSize_, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp index 8960ba7c5b66b831b555584fe2bf4af4ee11c024..6c50d222ecaf19d1026f7cab1478b413de77af10 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp index 95d9c072657cf08dbeda045ba4904bc4b1650883..5e91f53b0f99864a874fd81f9889cf721780c980 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp index dd6734061eaadbaaff1f8134e96acddf29f22d41..171803156051497dce6d05bea27663fdc831dbbb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp index 85f75110dfb4ce248cce75434f425bdf7c0c3e39..03bff5912dac549457c5c51b1e5b501b9ef18f50 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp index 7f62f4e01064028156e6d00dbe70ce13102c1d11..676af19336c5ca205ccabf31028bd57dfe577d49 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp index eee771b13373499d313d4c072608a86d966ebdc4..c67c0a41f85154a9ef5c230958e0e88893d31915 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp index 64f1e9c22be3cc0a8b8f5cf2f40e795a112243be..c9058a9b1d33a2f11f0da600edddb436a2941305 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp index 078561e1539bdc5113826c6220807c858186c8c2..1fbd3c6e7de827c71e01a601b7edf724cef9a6f2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp index 5a9144186bb53fa8a182b5034e0306be2edc7bbc..89496a4ed16926eec96ae1b5ea560b3b1fd90637 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp index dc4740aa3de33e40f3c2bd8506b7d27ff92669fa..fb79ce8adf39c6fd732d59c4a19f82d8d253cf07 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp index 9ecc96797ffe267ac5d5842450ecc604b65e14b0..7915bd2630a25aab209cb42ba69087b54a1a0d76 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp index ccce78e2f1541fd948c6ac6866b747ce6d08e928..d769fe8154918b80a2fcbadea95ffe31c57cb128 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp index 6d3749d868a8fde91ee7c96c57fbf7b6f450eb6b..49938d4434ca072b789aaeac75a8903d9afed112 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp index 7594dde74dba43a2e13cb8317754563bb04efdeb..810249994f00d95cb00352eb28ba464f6ecd7856 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp index 3272e7f9af35f50dd5cae8eb200b6f70f5251678..6640ffec700bba666962436a63b62ab828cfc719 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp index 519ec8271dd82b83597ddff695ea00588dfbd0b7..441bea275c164c60ff9e73af2a30c3a021be0090 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp index 77b2fb9306bebc9cd061ec94e3f074aa5aaeb897..05912e1e960f33cef0d82f56798326e3dc65ea2b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp index 5abb5c5eecad17c988c5b5d21189c29dfd50c0ee..b0e16cfeaaef2866134c1dfc5d2a22908404a499 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp index 23bd988b8a5d63d5aa284514459a206c7a38bf4e..0ee0f7d0fe156e6fb9a618e0bc66d2a9f4e2e7c4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp index 7ce5577d7faf64a44569e9e97d3c7143378ab852..feb9d99eba11da1c6b2cddadffd80d3fdb4cb7c2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp index 7e4c5b77f0855ec5fcc7e745b5ad72071b8607d5..a9f0d77cb885b3be117c0621357d12ecac0c599b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp index 5eca5fea7f83e7a9616489af2ee3a3b5fd845b45..bd1f72250a36137736e7bdd3dc5a11d842103539 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp index b0e98411bf5ca8367283624ebe068785668ac227..60189fd3047b77a3c4ef4f34be9a75ab69dc73ac 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp index 84609a995dd9f00c8686322285842b030298437f..b5aa9ce613228d41a89c2787b8bf275f429fcba0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp index 2f816bb11dcb7cd3d97311ccce038531271db282..0e20a17ff98e593f93c91861b47449ad6007a6aa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp index 9cecd4a5b47afe8d0595a6aaec13b39042ed4c32..c8c1cd7e802022268d323582b7955a366ee743ef 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp index 42e9b7fc792b70ce5ee85e22cce577f7a02d2ac6..3e71b44673994b1a0fd460191cc29c36b7ac6b97 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp index 494f1c3d71809f4a1302a1933e26cf89a754ad35..c8286e56d19377ca1c4a9b4fa9a714687a0b5fcb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp index a80abb9247d595fda58e5d1c2a461bc009831b79..36fbc52f113c43a71d6cc0c647b01a16f2a16c4f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp index 53fd286383bf91fdaae001d34be26e3bafee8d58..0900cffa353f778e8ce544ee7efbbc1ec9e8fd0a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp index df5a4db484953c0c7bb00d6078dc701b681d3950..726b56066b000521a697989b4f5eb807a998aeba 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp index ed78acd926882844a94e13ce6bf0d6431380efa5..0ae9838607d53897431229488e177229d88561b8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp index 0038fc26da53d55648a4374af3a02cde6d86184e..a3a39b7ccc7216b371a540cb8eaebee075cbe4d6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp index 36eb092f0f0707b266ef9a1d7c3b5248947c1e87..26815f1447ce241bdd53003bb43d3a9df25b684e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,34 +9,33 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_softmax_f16_f16_rank3_instances( - std::vector>&); -void add_device_softmax_f16_f16_rank4_instances( - std::vector>&); - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>&); -void add_device_softmax_f32_f32_rank4_instances( - std::vector>&); - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>&); -void add_device_softmax_i8_i8_rank4_instances( - std::vector>&); - -template -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device:: - DeviceSoftmax> +template +struct DeviceOperationInstanceFactory> { - using DeviceOp = - DeviceSoftmax; + using DeviceOp = DeviceSoftmax; static auto GetInstances() { @@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory< std::is_same_v) { if constexpr(Rank == 3) - add_device_softmax_f16_f16_rank3_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs); + } else if constexpr(Rank == 4) - add_device_softmax_f16_f16_rank4_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs); + else if constexpr(NumReduceDim == 4) + add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs); + } } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { if constexpr(Rank == 3) - add_device_softmax_f32_f32_rank3_instances(op_ptrs); - else if constexpr(Rank == 4) - add_device_softmax_f32_f32_rank4_instances(op_ptrs); - } - else if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - if constexpr(Rank == 3) - add_device_softmax_i8_i8_rank3_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs); + } else if constexpr(Rank == 4) - add_device_softmax_i8_i8_rank4_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs); + else if constexpr(NumReduceDim == 4) + add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs); + } } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp deleted file mode 100644 index 83f52fc3ee7a9547161ce1284c2844976819c0ec..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f16_f16_rank3_instances( - std::vector>& instances); -void add_device_softmax_f16_f16_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp index 046ff5780556331a6afc697f87dbbf02dcf1ff25..3fd2bd089ed8fdffa2d40c7b6b10ae146322ee6b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp index 8e6a226f6a13fb1dcd7d61b56b476f38072b6a98..210fdc0a58548f8388602cb960ae0b3c019e05ba 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp index 518fa5f98679099189f15dfe5a57e3ac0324e3fa..894fb034d0dfa86ec85d76b7f015662af1d775d1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp index 10016cdd707cbcb46282e540e15e6b79a67c6a78..708ef0ce130e1095c4c88b520d1dbdbcabf23685 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp index cdd5a3cd7b6fb6806db3de0a10b7ed87b29c5c2d..6754e5ceffa02d4bcad7919c7ea25b0dd85c0cc4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp index a8be272e02014495cb42d0509b50075f2d7da707..5e111176e198d77b8e0e426f2a578cd9c3dd2b3f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp index ec8296ff22fb6591152a3dcfad63c250c70cabb8..a3cecb32f83f707f54a53b7ed107f2662f4a2f08 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce4_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp index b3877c4bb3f6217427f69c07f7073751b254c938..8c0782daa556168ac359f68d301080720c622e1b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -16,7 +16,6 @@ template using device_softmax_f16_f16_instances = std::tuple< // clang-format off // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - // fallback kernel DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>, @@ -33,6 +32,13 @@ using device_softmax_f16_f16_instances = std::tuple< // clang-format on >; +template +using device_softmax_f16_f16_generic_instance = std::tuple< + // clang-format off + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp deleted file mode 100644 index a6d9a359f4622ec9534ab090d46e363bd9101177..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>& instances); -void add_device_softmax_f32_f32_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp index 6621a2c867ac90402a853e5c2065663a96220927..4cc469025335831b8c502e0693e92894ba454202 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp index 3dfac98ed8beeee6060b6ba3f6b9b1ff19f660fe..65724d7888ac9801cc2dcd97ce56dd8681ca6cfe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp index 6d2a0c932500e874641d8cef9fd65b4d280c89b6..13bd45598eca17c81df5b8eba6fe41f3400057b5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp index 97dd3dcb18aed5c228313a961994096445948bbf..d58b424ee94ffad33440525e190c7adf45598eb2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp index 58f8760acccd1b571df40ea9a08946c1f5f9764f..378e45eeb783d85a6da303349c7bd7b0329f7fc1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp index df8d31f0da76b0f027b4ee4818006dd331e38454..293df08c7e9e8edd9a419b7e94a5d83fa46efe8f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp index 1bd773227e172278a466d7650e90fd39ec3e56f8..e503a9fec1f5c13e71b297d985c3de8c15497758 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce4_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp index 16f129d2d07c19a926c75f1e878e8db3fb9bf821..90c5ddc8a01b7fd0ef58a8f744133a0635e2fe29 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -16,7 +16,7 @@ template using device_softmax_f32_f32_instances = std::tuple< // clang-format off // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4>, @@ -32,6 +32,13 @@ using device_softmax_f32_f32_instances = std::tuple< // clang-format on >; +template +using device_softmax_f32_f32_generic_instance = std::tuple< + // clang-format off + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp deleted file mode 100644 index f80f712ff5e501d6516d9c7210a59276a96b6541..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>& instances); -void add_device_softmax_i8_i8_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp deleted file mode 100644 index 6f9952e7d58a448cfea0e3c6ba84af24e9de4e52..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce1_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp deleted file mode 100644 index 2cbd13a1ba5226316363cfb07dfd23adfac4aa9f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce2_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp deleted file mode 100644 index 7b12522a85987f85e83a6f8afb34b97844507c49..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce3_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp deleted file mode 100644 index 54d477f80c5b222cd5a76430dae3bbd4db81068c..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce1_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp deleted file mode 100644 index 4ffc44e3a9217e689297bceb81bd6a38c41cd1c3..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce2_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp deleted file mode 100644 index 08cbb81272f9ff5e7ce9a9fdd9075d81f1ed7211..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce3_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp deleted file mode 100644 index 187d034b95ac506426da4fa1b8240b61e2517e46..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp deleted file mode 100644 index 7fc9ed69198f1e7bf447a6baa9dedb3286440a94..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using device_softmax_i8_i8_instances = std::tuple< - // clang-format off - // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - // fallback kernel - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 1, 1>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 32, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 64, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 32, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 64, 1, 16, 16>, - // Reduction on middle dimensions - // InSrcVectorDim is 0 since we want to coalesce reads on M dimension - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 8, 0, 1, 1>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 32, 8, 32, 8, 0, 16, 8> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp index 03be6e2bc7c15ee350569f7ed97ebf2d3ad3b515..10f99acb8d6fe4c5d1fb0a065231995a3d870392 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp @@ -1,8 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" diff --git a/library/include/ck/library/utility/algorithm.hpp b/library/include/ck/library/utility/algorithm.hpp index 86f04dd362397e774c9821a16a7c412dea000937..57136f8a2a19fbdaa2e0276699c786c33cc61473 100644 --- a/library/include/ck/library/utility/algorithm.hpp +++ b/library/include/ck/library/utility/algorithm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index a89d03d324f33f0be95f88c328a9d7aee38e7460..7f63a81a020617a00bd91ab529add56dd2fc55a9 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/conv_common.hpp b/library/include/ck/library/utility/conv_common.hpp index 6fad9f7d77d349853e63512ae486a57e7ceb7b6f..085454f42dee006cf7c48ebd18804df92fa36a37 100644 --- a/library/include/ck/library/utility/conv_common.hpp +++ b/library/include/ck/library/utility/conv_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp index 2b4f63b28b8d2f1e59cef67b99a577314faf4ae2..ff697fb71c0340df3f3aac6953ce76cfdef4b278 100644 --- a/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp +++ b/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/convolution_parameter.hpp b/library/include/ck/library/utility/convolution_parameter.hpp index f4a2b56f75a6b40a476080e74d386fab524286a4..df6efca10808beb96535ef69a3ba6116a449b9e3 100644 --- a/library/include/ck/library/utility/convolution_parameter.hpp +++ b/library/include/ck/library/utility/convolution_parameter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/device_memory.hpp b/library/include/ck/library/utility/device_memory.hpp index 87940e1671ab71b0e418c82fbec1f95d6627a8b5..1c16ff59165af029e2f81d8364daf07f844daaac 100644 --- a/library/include/ck/library/utility/device_memory.hpp +++ b/library/include/ck/library/utility/device_memory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index c0bc3727641041bbd1d9fb0c5e7456c5042b6448..c01e139ea0b8f160f790bde4270640c261affa6e 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/host_common_util.hpp b/library/include/ck/library/utility/host_common_util.hpp index 6f4466e8da0e2b77c1ad981cfb6f730a74a92d71..20a8f234dbbf9f30582ed371c5f1a90cce2b3196 100644 --- a/library/include/ck/library/utility/host_common_util.hpp +++ b/library/include/ck/library/utility/host_common_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/host_conv.hpp b/library/include/ck/library/utility/host_conv.hpp deleted file mode 100644 index 8348a3089f4606563caa5582ac427154308e7fa5..0000000000000000000000000000000000000000 --- a/library/include/ck/library/utility/host_conv.hpp +++ /dev/null @@ -1,152 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once -#include "host_tensor.hpp" -#include "conv_common.hpp" - -template -void host_conv_nchw_kcyx_nkhw(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&) -{ - constexpr auto I0 = ck::Number<0>{}; - constexpr auto I1 = ck::Number<1>{}; - - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += ck::type_convert(in(n, c, hi, wi)) * - ck::type_convert(wei(k, c, y, x)); - } - } - } - } - out(n, k, ho, wo) = ck::type_convert(v); - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); -} - -template -void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - const auto Di = in.mDesc.GetLengths()[1]; - const auto Hi = in.mDesc.GetLengths()[2]; - const auto Wi = in.mDesc.GetLengths()[3]; - const auto Z = wei.mDesc.GetLengths()[1]; - const auto Y = wei.mDesc.GetLengths()[2]; - const auto X = wei.mDesc.GetLengths()[3]; - const auto C = wei.mDesc.GetLengths()[4]; - - auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) { - // do__ must be converted to signed integer, otherwise zmin might be wrong in cases - // negative values. - const int do_ = static_cast(do_tmp); - const int ho = static_cast(ho_tmp); - const int wo = static_cast(wo_tmp); - const int zmin = - std::max(0, - (in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) / - conv_dilations[I0]); - const int ymin = - std::max(0, - (in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) / - conv_dilations[I1]); - const int xmin = - std::max(0, - (in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) / - conv_dilations[I2]); - const int zmax = - std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]); - const int ymax = - std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]); - const int xmax = - std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]); - const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0]; - const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1]; - const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2]; - - double v = 0; - - const TIn* in_n = in.mData.data() + n * Di * Hi * Wi * C; - const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C; - - int di = di_min; - for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0]) - { - const TIn* in_n_di = in_n + di * Hi * Wi * C; - const TWei* wei_k_z = wei_k + z * Y * X * C; - int hi = hi_min; - - for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1]) - { - const TIn* in_n_di_hi = in_n_di + hi * Wi * C; - const TWei* wei_k_z_y = wei_k_z + y * X * C; - int wi = wi_min; - - for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2]) - { - const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C; - const TWei* wei_k_z_y_x = wei_k_z_y + x * C; - - for(int c = 0; c < C; ++c) - { - v += static_cast(in_n_di_hi_wi[c]) * - static_cast(wei_k_z_y_x[c]); - } - } - } - } - - out(n, do_, ho, wo, k) = v; - }; - - make_ParallelTensorFunctor(f_ndhwc, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3], - out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4); -} diff --git a/library/include/ck/library/utility/host_gemm.hpp b/library/include/ck/library/utility/host_gemm.hpp index 44036d0234375824ed3aa773ffe71de02b84c570..5eb7e3b8c9745c48390482622d361a2ddd491159 100644 --- a/library/include/ck/library/utility/host_gemm.hpp +++ b/library/include/ck/library/utility/host_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 29d94b0036c303d4cc0053803f73eab3b71c0791..816d8341308d27837f7b8da1b5af19e88f223665 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/span.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/ranges.hpp" @@ -411,6 +412,12 @@ struct Tensor } } + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + return mDesc.GetOffsetFromMultiIndex(is...); + } + template T& operator()(Is... is) { diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 4259862e65e331b30377997c512459e2a21e9383..31ff13aec254a166c6439fdef3fd17907017e147 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/iterator.hpp b/library/include/ck/library/utility/iterator.hpp index 9fdc88ea76825e2a1cdf659f95fa2928038867aa..b44e2d8e3c4a6f8b4423b03699539ade2fbfb0d7 100644 --- a/library/include/ck/library/utility/iterator.hpp +++ b/library/include/ck/library/utility/iterator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/literals.hpp b/library/include/ck/library/utility/literals.hpp index a73a2ea054150665557b8e1da473f2cb9877526b..a8bd6303f12f0a36bafb36c1d7f1453eab13392b 100644 --- a/library/include/ck/library/utility/literals.hpp +++ b/library/include/ck/library/utility/literals.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/numeric.hpp b/library/include/ck/library/utility/numeric.hpp index 70a7e87ab1c8d94569e705b3f6832c78bac05853..9ee118d4757109416bac18cb54229e93de59b75c 100644 --- a/library/include/ck/library/utility/numeric.hpp +++ b/library/include/ck/library/utility/numeric.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp deleted file mode 100644 index 78812e8c81d600ab1258400205b98994ab120d3a..0000000000000000000000000000000000000000 --- a/library/include/ck/library/utility/op_instance_engine.hpp +++ /dev/null @@ -1,249 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ck/utility/functional2.hpp" -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" - -namespace ck { -namespace utils { - -struct ProfileBestConfig -{ - std::string best_op_name; - float best_avg_time = std::numeric_limits::max(); - float best_tflops = std::numeric_limits::max(); - float best_gb_per_sec = std::numeric_limits::max(); -}; - -/** - * @brief This class describes an operation instance(s). - * - * Op instance defines a particular specializations of operator - * template. Thanks to this specific input/output data types, data - * layouts and modifying elementwise operations it is able to create - * it's input/output tensors, provide pointers to instances which - * can execute it and all operation specific parameters. - */ -template -class OpInstance -{ - public: - template - using TensorPtr = std::unique_ptr>; - using InTensorsTuple = std::tuple...>; - using DeviceMemPtr = std::unique_ptr; - using DeviceBuffers = std::vector; - - OpInstance() = default; - OpInstance(const OpInstance&) = default; - OpInstance& operator=(const OpInstance&) = default; - virtual ~OpInstance(){}; - - virtual InTensorsTuple GetInputTensors() const = 0; - virtual TensorPtr GetOutputTensor() const = 0; - virtual std::unique_ptr - MakeInvokerPointer(tensor_operation::device::BaseOperator*) const = 0; - virtual std::unique_ptr - MakeArgumentPointer(tensor_operation::device::BaseOperator*, - const DeviceBuffers&, - const DeviceMemPtr&) const = 0; - virtual std::size_t GetFlops() const = 0; - virtual std::size_t GetBtype() const = 0; -}; - -/** - * @brief A generic operation instance run engine. - */ -template -class OpInstanceRunEngine -{ - public: - using OpInstanceT = OpInstance; - template - using TensorPtr = std::unique_ptr>; - using DeviceMemPtr = std::unique_ptr; - using InTensorsTuple = std::tuple...>; - using DeviceBuffers = std::vector; - using InArgsTypesTuple = std::tuple; - - OpInstanceRunEngine() = delete; - - template > - OpInstanceRunEngine(const OpInstanceT& op_instance, - const ReferenceOp& reference_op = ReferenceOp{}, - bool do_verification = true) - : op_instance_{op_instance} - { - in_tensors_ = op_instance_.GetInputTensors(); - out_tensor_ = op_instance_.GetOutputTensor(); - - if constexpr(std::is_invocable_v&..., - Tensor&>) - { - if(do_verification) - { - ref_output_ = op_instance_.GetOutputTensor(); - CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); - } - } - AllocateDeviceInputTensors(std::make_index_sequence{}); - out_device_buffer_ = std::make_unique(sizeof(OutDataType) * - out_tensor_->mDesc.GetElementSpaceSize()); - out_device_buffer_->SetZero(); - } - - virtual ~OpInstanceRunEngine(){}; - - template - bool Test(const std::vector& op_ptrs) - { - bool res{true}; - for(auto& op_ptr : op_ptrs) - { - auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); - auto argument = op_instance_.MakeArgumentPointer( - op_ptr.get(), in_device_buffers_, out_device_buffer_); - if(op_ptr->IsSupportedArgument(argument.get())) - { - std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl; - invoker->Run(argument.get()); - out_device_buffer_->FromDevice(out_tensor_->mData.data()); - if(!ref_output_) - { - throw std::runtime_error( - "OpInstanceRunEngine::Test: Reference value not availabe." - " You have to provide reference function."); - } - // TODO: enable flexible use of custom check_error functions - bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData); - std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl; - res = res && inst_res; - out_device_buffer_->SetZero(); - } - else - { - std::cout << "Given conv problem is not supported by instance: \n\t>>>>" - << op_ptr->GetTypeString() << std::endl; - } - } - return res; - } - - template - ProfileBestConfig Profile(const std::vector& op_ptrs, - bool time_kernel = false, - bool do_verification = false, - bool do_log = false) - { - ProfileBestConfig best_config; - - for(auto& op_ptr : op_ptrs) - { - auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); - auto argument = op_instance_.MakeArgumentPointer( - op_ptr.get(), in_device_buffers_, out_device_buffer_); - if(op_ptr->IsSupportedArgument(argument.get())) - { - std::string op_name = op_ptr->GetTypeString(); - float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flops = op_instance_.GetFlops(); - std::size_t num_btype = op_instance_.GetBtype(); - float tflops = static_cast(flops) / 1.E9 / avg_time; - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << op_name << std::endl; - - if(avg_time < best_config.best_avg_time) - { - best_config.best_op_name = op_name; - best_config.best_tflops = tflops; - best_config.best_gb_per_sec = gb_per_sec; - best_config.best_avg_time = avg_time; - } - - if(do_verification) - { - out_device_buffer_->FromDevice(out_tensor_->mData.data()); - if(!ref_output_) - { - throw std::runtime_error( - "OpInstanceRunEngine::Profile: Reference value not availabe." - " You have to provide reference function."); - } - // TODO: enable flexible use of custom check_error functions - CheckErr(out_tensor_->mData, ref_output_->mData); - - if(do_log) {} - } - out_device_buffer_->SetZero(); - } - } - return best_config; - } - - void SetAtol(double a) { atol_ = a; } - void SetRtol(double r) { rtol_ = r; } - - private: - template - void CallRefOpUnpackArgs(const F& f, std::index_sequence) const - { - f(*std::get(in_tensors_)..., *ref_output_); - } - - template - void AllocateDeviceInputTensors(std::index_sequence) - { - (AllocateDeviceInputTensorsImpl(), ...); - } - - template - void AllocateDeviceInputTensorsImpl() - { - const auto& ts = std::get(in_tensors_); - in_device_buffers_ - .emplace_back( - std::make_unique(sizeof(std::tuple_element_t) * - ts->mDesc.GetElementSpaceSize())) - ->ToDevice(ts->mData.data()); - } - - static constexpr std::size_t kNInArgs_ = std::tuple_size_v; - const OpInstanceT& op_instance_; - double rtol_{1e-5}; - double atol_{1e-8}; - - InTensorsTuple in_tensors_; - TensorPtr out_tensor_; - TensorPtr ref_output_; - - DeviceBuffers in_device_buffers_; - DeviceMemPtr out_device_buffer_; - - template - bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const - { - return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_); - } -}; - -} // namespace utils -} // namespace ck diff --git a/library/include/ck/library/utility/ranges.hpp b/library/include/ck/library/utility/ranges.hpp index 55c322f1ace041b9c4f346a5df4250c2323759dd..f11e4204a0cec5378328f69d3bb553a6227c47a3 100644 --- a/library/include/ck/library/utility/ranges.hpp +++ b/library/include/ck/library/utility/ranges.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index cc8787458d7ffff072cf8c055c8def43ca487e51..e730e9f589faf8047dd60a218ef191eea510cb63 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp index 04200cfb52e26f227d89852e2845547e6624c22d..f6696ffa9f6ca2c0665ecde3f0394c0802588a46 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp index 7b86f3cc728fdac1117cbe927445b6830e81e9ab..32d6f258b5c9152205f8b13e0326739b567b0845 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp index 2afb1afbc3d2355f682af9cad21b62c0ab1ee8ea..ee246ba56edd7271d7d2f980c62003a7ab7c7f06 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index 68d768949c859f5ca8aedfebfa888253dc6c276c..5a9483b3099d9985f88f0121135a79d55edefd5a 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp index 737e5bfca3bccbafcd8179d80bb43b58bd545639..0fa0719237b000b7ab9b13d94cfa951edacf3139 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp index e09d01736083f54cf2c1b848814b2d7fb683d3b8..2f42f62b0b14b91bf94208b27684ba4b0b10c5d2 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp index 984d66e2884273e006772530e6a978891766a8e0..10b4cea7d7cddd4e694c0de42d122bb7ff3bd505 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp index 12cada9c44d61d8b81ad2dfa4877dc57b26ed681..c687eb20f06918b41ca54fa72a2a577b809eb2cf 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp index 13f198862e5583b1f63c51efaab11fa733976cfe..b19374ca6d5635238da838808d863b19fd1b6b93 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp index 2ca1adc2f6d160a97a2f40b4075f07b185e68351..bbd318ba9299005910d63a83d6906f4b69571479 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp index fe5de52796f6c9a0c94a84d883a7b1a714c4a652..187ccb5ff62fab57a83139f340a4032fd03e8dea 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp index 5b55c8e15e01d879cc11e57128c3e2dfdf81890f..ec2b2646ffa1a78b058a90b0508eba7f1d65a2be 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp index 9517e4577e5b9af188aaf07413e8c0b3cf6c32ee..d76cd350c81807f330f7b04ae2ed2d1bf6bb5080 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp index 43b91244060ed90c9f2558dc99279ce0a086845b..ef65106c2cf537ae29dcc7313829ecaf3ae0ed58 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp index 326500fcbf013e8540fff9252202ac7af33b2938..078b241f9dfe194778806824911d791c7dd7a2de 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index e1bfa88f49d4fb886f953661c03b8638516ba646..4db05589b153b2307d377e67147b5ee1bf8336bc 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp index f59b742534b2c234b2b4908f5e540f591b104619..e25f903a8272ceb6bd559b8edfdb72f7ac759716 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp index 04a748f45513c56ce23781dd08bcf4c2bae03331..a0afaabbc7bbff785c7a3e6cba884a0bb1f768e4 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 9b96194c874c3daea7c9c4eb931368ce5645b50e..67dfc4cd35066053c7b9936a76c402454cf1e99b 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp index 0713dfcd99ca183d8f64cd23ac6067be990b39e7..9001c901cfbecb473c6a0e4f70a9db6526e2a4cb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fda55a930363b43a91235d39c3b37cb028fdc380 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt @@ -0,0 +1,18 @@ +add_instance_library(device_batched_gemm_multi_d_instance + device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3fe9f78b2a5b00fcfd563f7dc2b57e74699cc50f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ab22bb03aee81dcd4656478cefc415a7c3fc212 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80c890cdb34263216578c811a115e6eae402f6d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..647c583036c71cdbe56a8a057309d5c5ff4f2bfd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ce582f1f984ec67bf0a50076e73b687ff06a59a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34c29d0380150e8d2c8b01e9cd806ffbf37258a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0da8e3d5a25e591e2fa096a9ead9b6d5118ebb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3edd3f78c661d5a31fd8e61102b7ed51ea386bcd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33234edc1757fffc5e2d093857b1c459afa121f6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16107e1bded3609eed81b8a3c1b5d3ed001decbf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e4bdb0172bfaad8abcdb133bebac3d3723bd257 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81077171621c6832453d601ed1309c9be4c0b29b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e943c88c30b4d5b0b6cf4576348fc5166f7bd97 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea5e7c562d43ee3b411e64d53b532ac4f0b8fd36 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..000a4b0130f4e31932fdccae84e8b4b8a0d704cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24fb67619aa03fa6e109d6091ba750c30058743d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index 521c3d9219b1cd23e30ca194d3e22d1f38fa9316..cb89d3cefd6525d017d7f7b4685be456cb1c7e44 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp index 231d612d781bc93c2a232f7e1478fddfb09c1b5d..91eefba0c10af87e7674cb053ee549a2799de7a9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp index 165bc3957d36479b1cb3052c0085ade58fb71322..c20798f557e58e870f748d7900f1418f3a1a5779 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp index 832fc3b066066f9ac8faae30cebff714b73b53c9..3d9ad64b9b40a2271ed40608735bffb0133153f9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 99e8712474996a6c30e8bf16ae69f5bc15925364..cf23d01bf234e2cdda3efb2b655879a4740f5c27 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index f73e3dea84ebe2f4e70af03df967bcd3b2bd7450..498bf58fb3e6cfaf444fc92923ff261389ac0ac8 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -41,10 +41,11 @@ template , ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -58,8 +59,9 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index fd094719584147781d4246a508ada7104cdf12e5..744bd6456d28f719a28508f65bdbd78bb02d4120 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -41,10 +41,11 @@ template , ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, #if CK_WORKAROUND_SWDEV_388832 @@ -60,6 +61,7 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16 DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index 53ad7ba5ffa9272da26189a45a75a9453af68f48..b342612d1c098070108696b7a76a8e7fe754fb51 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -45,6 +45,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_ // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -58,8 +59,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 291a127a6603bd5fe1256076355df76d133978e9..3fd0c07370d40407dd3ef754f55ccff9deb091df 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -45,6 +45,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, #if CK_WORKAROUND_SWDEV_388832 @@ -60,8 +61,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp index b62c8b99cbdb6cf42922e843c11a0832614b6ea7..bc95d2f1b1f1ad3734daaa19ed5dbb1faff5a6a4 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp index d05b8b592c29aceef387a4fdec8c02854501995f..fbc8d0bc6097a477fe9920201167a58142aa9b44 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp index e3ef95d12e17acfebc029ac6e1a878194691739f..bed38658a99c109370e6271dc94f785d94268106 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp index 41be396c24a28b5a7fbc47b10d7cf6ae5e6e07e9..fc5ec77e428dfbfb09a902245f27ca25b8bc0d35 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp index cd1e05b1133ad2a2515a936ef968667a03bc11be..4e38ee13b24dd150f1366ed7694e69c6969a2e2c 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp index 073dd583f97e0eed167e4bc4ef952530e82854a2..f087eb79824e3e36c6bc07a3d42b68f38a89c16b 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp index be63bd44c66c3204c66990ff6048bd895841c84f..d0f361401a500c16f1228d4df1d3de6714d1f43a 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp index fe87091e8d66e32af0832c1b013abce0911673b2..710d07b8280cef467d2fac8a70034129feb13fda 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp index 2e695afa9786d2e4fdcc6d9bf05c600c1a333181..8801c309f9138ccf321ba1dcac5f1718287e0a76 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp index 9ec761e445ac30cb482c35be42846c3e46a13ea7..b674cfc423d81b5efce31169e1c81741ba156b9a 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp index f0d26c36bedfa037a439e6226e8b1db4051d0fc3..05e3650887c3dfa84d1f03a8380d1144c7886acf 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp index 9e4066bb060e0a58342956860356ee589bd5e9c0..15a02af021a22b4e9e43186a7b3e18b938534dd8 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index ffd6a6a7be2081429508fbbaec0edcf6bb1dbf39..d2a0a3d0fbe7a869b2d1603fe8879a201c1ca4b5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_contraction_bilinear_instance + #float device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp + #double + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp index ebbff88346b9dcb1eb4716b1b36f456b806ef1ca..5587db77e08ceb40242084853ed07236cc992605 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp index 980383f3e71b3f1058c295c9aba55c56f596134a..26262855ea8b6e3ba2c929e527e97dc5108495dd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp index 2d4b6e3489a9ccce619b0a76dc097f04d72b0d57..befc0dcd10d07735ec046e8ef4d64bb54be4756c 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp index 7caa469f54b0169449250f000df038820e35a32f..e45b47cf94ed53de17b862e50ed2ef7a00e3e978 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f437a227d50b665b999e30b2d065ac6825e7ec04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13fdbeb35c53a93947f739dd26f91885cff4a72b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95ef8c4929b00263c65034ac1ec48195efdc8fff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..290f81d7c9a926dd4e9c169e3515ebdbaa547788 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index 7ad6605486caa7fb4e60e717db6756f52f453bd8..31f6a0fcdc9a04d15d63ffc93e2dac550a0ed293 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_contraction_scale_instance + #float device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp + #double + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp index 5118d0d033ee733c9b3bd7bc3524824c77aae667..16fd1cb407bf81741845c41baa580f208e98508d 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp index 655d4f0061ae1c417e610e9059f7e2bbd2048363..ff37bf7cceb40f75c8a54309341914539057d678 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp index a9d20be18bf5c52044b12f84c58f7824f8ef8142..8a1f6f93341edbb585157cd66308ab8f781bfedc 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp index a68f5c9718a1e547c62f16888a4fd0a0ed1d8197..d333f597268d55a195da671fcbcdaab2c03acfc0 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c87b51a9e426945f049d471bf3591d1f222b72b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd3f57c6b6058f16e9e34158fefc71e3fb9c5f3c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e53f0b2f51f23143b1d49d1b7d9b2d3e2b7b765 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d02d146a9bf98b25fe7de56148c8879dd62a328b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp index 5a5c8384227150dd8717433dce6c9ea2395896bd..e3e90c966dd6efd920ea0147d434f926bd07bd05 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp index e0f3d6199f8939bb0f1ce7f8efc48f2a92cfd3c6..81e9122d9506eb3ae78602c7beedbd2ce1b0b745 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp index 30537d9373b8175207b267666574033cae4e1a3c..dbc82168f49ad102886ce8aebaea72cd3ef5a99a 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp index 190c39b870b3cb33e9080c8b591884f4ec387347..3ac250f3e6a9d3fe8b3139aef5a79dc520269381 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp index e14cd558628772991da9b9d4ebb3dd5d852e3285..fcb858728e80c2cce14579f5a35409316deceeb4 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp index f001b83c1711a7185d71f5b2c7f795005b35b696..2baa6ae061473a44e962c597d2fabb9d62cabee4 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp index 83ba6a1c6bb8a3098da1b31a59a1d4a12b8b062e..28867fe19dd62f63d2e3df798a210a1045415fb8 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 1da9a81d904b3e43aaf5ace4f612c43c05e66ee0..40656e382b8074580de72fed78123c07b5e1e303 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index 7c33df5e7687ee3e279b350bec9ca17902e84231..422e37e926fd7a839afc2cf884b141ca32737078 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index a5f8629f2daef05ca1dbcbec2e8c4c668bf96a84..5993f6bd7a28c2992bf2e069880e84bf75612dfd 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 8076d6d35669114d84b922a59948b888dcdab506..bb9b696864c54265a8d3d119bcd55d82452a825e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index 33503b9f8ae00a92c25b178629736c7e8a5474c3..da96c79a6e9a44edd05928b54a865483f6a1e1a2 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index c5e4bd199ea8322312af832b63a0ea559109131a..78e9c893c9e82a723d1794685c7d33d78129047d 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index f43d13e309333bc86abf22465b7f035c1445321b..4663a9ac32ac3b6424d638c2d6bc7248651c240a 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 0ce6b04c4218190ccd3cad238f1bc90ec3002df9..0b1df52f8c1ce4503c154c4f7ce40b713455e839 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 76ab3189d7f210446b305a3daec80a2902de63ff..9969a8bbc90df83a31da50f8f064cbe0a83923f7 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp index f8c255088877e95e2cbbbacbfba6f9e1f4fc0291..e34ea06ff46ba090a212575e66dcfd454a44de66 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp index fe7152471ed54aa5334e820f993232912013c6de..3254fcfc26467edc3ec6b14d1838e784d7438529 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 04ce7c07639c968d90a8053f4b67b1335acc9b69..94b2a47e50b635e5b862c109f479660183d2bede 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp index 0251d9157f3b894208b17091ab6c79eec8bf605d..4244ab7b8770c9cbd60957c4fd66f004b7f5975c 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp index c2975727e48385207678abdcbc1fda776e761664..5c7db4ca3bc4dce89cd1d139722d1e91758f8d7a 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp index fc86d7302450a01263dc07d8f895c669795f3644..ebc56487a159afe8b32e5e321c6996c38250c32f 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp index 182037f15c66c2d196e86b84ca0f0571be2922e9..f2a5f0728ac1d8bc4bff78aa5476e872a7085163 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -30,7 +30,12 @@ using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std: //###################|| | functor| NDim| MPerThread| | | DeviceElementwiseImpl, Tuple, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >, DeviceElementwiseImpl, Tuple, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >, - DeviceElementwiseImpl, Tuple, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >, + DeviceElementwiseImpl, Tuple, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> > + // clang-format on + >; + +using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance = std::tuple< + // clang-format off DeviceElementwiseImpl, Tuple, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> > // clang-format on >; @@ -39,6 +44,9 @@ void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( std::vector, Tuple, Normalize, 2>>& instances) { + add_device_operation_instances( + instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance{}); + add_device_operation_instances( instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp index b160d4fe1a3a271113899b2a8bc57e5535706280..3e2386ee03558ec6d346abc73a037e65e2d9cde3 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index e20d592c84ed8da549f29160c87f65579307096a..d66010af734581ae204015e7f3d663040d0b111e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -33,11 +33,19 @@ add_instance_library(device_gemm_instance device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp index 5d2f18e14e37be6b1c5528e8f43f23d1b6c88996..ea99a5a30ef80bdb954df78535a6cd63b8663bbb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -34,7 +34,39 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d444e29aa33538874b306caeab0fd78e177ce807 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp index 01e3b3793a487fd5eb6b6b5a6ca6a10e86bd0588..b83acfa8cb4db572bb33ccd20e6d38e7594f0234 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -34,7 +34,39 @@ using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8f8a0bd333748899bc52eed1f65ff6c56467a87 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp index 804e86a066ffa47c6d96de81cd4bef5644a5b131..d5800e03336e9161e01ef240a46aefbf9b13c2c0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -34,7 +34,39 @@ using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..840a4fabe1c211689796debc86e8ce63e0c19b4e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp index 159fa90f7432248721f3bfef606f6df3b04584e9..abe52ce1da0edb191755d205a4e744c274ae56b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -35,7 +35,39 @@ using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82b1b5dc29fc4e5612236f2e0271597e405dfe80 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp index d8e7798438dcaaa9d7c2e48e8074a2f58a635438..e696bfdcdc07f6799a00b9463837031559b715a5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp index 0034ac59c3888d8dcc1b2bf8e136ac4f02c2e6ad..d3ad7c60ec6f8cf2321aaa03e18e62ac28e65773 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp index 0b540b8b3492c355d75adf98cacc1826417db0b9..a56a36b0ab9cf0260a0366560d8816f18f1f5d7e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp index 4f6ff5111b1fa5dbf39e4c60d1770a319aa9dbba..63d55e81d6a5cb1157ea72a97e01599e6fc750ff 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp index a4208245e5756ccf3a2e6be29c3a5c5f73e97acc..3d9a265c29b35f09699fb12b578942a5c5f9ed89 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,7 +31,40 @@ using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0439201511f285a0d7eba405df72d1942a4fb49b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp index 06fab7f68a2efef857107e3ef6038b4505e21712..240384d19c5eeebb86ac92be53d22605b1862801 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,7 +31,40 @@ using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..350834f7e515b93645e2f22285c19f9b9285304c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp index b6d72fa221f4b946f3a25363a99b737a8a830428..e96905247d6a72e057bf13d821fdc17f78a380cf 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,7 +31,40 @@ using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27397527bc973c56bf15a8c35e86dec89184959c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp index 67d2e3ce4b41c097fffb3099b97a8d398b44b743..124b818b28cfe1145a5936a9ea77500b980df0ad 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,7 +31,40 @@ using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b99f3f2b63902c4b832c86dcdc3f74376d972dbe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=128, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=64, NPerBlock=128 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index 03eebf4ec3062cf5f85486288673fb96fb2761b1..2e884dfc8afba8d6c6ea88a8e3159c7a62660364 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp index 5d8de04cd9fa66931dde4ed0494fc429659506c3..2ca29b1e6f0791154d16ec1b5dd134d8d9c77c37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp index 7b12b7cf1f42da6d8c321c1ac121166753b7ff59..706076098da912ea9e5a844ccb2797cbc850a01b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 730ffd4633de3c72cb6ab932126a18d8798418dd..5ac458a7b985d16faf90a53866ec4f9d2bc9c6b0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 619473ff0ce40a5150c59efdfd62bc905d7ce83f..a64412544721237542c729d0f9581acb39917555 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 8e06f9d26b4d3cf8a22259a54e103c7b8ce76190..44b684823391fd83caaf4b55a1ef41c19fe5b3c1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp index f9458b7483c5c29011b4aa063061436da23f9e36..23176269c2c84808ad05a874259694f22c6064d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 77a03b746cc12e5f8cea8772c4432c9fb4e24110..31a9abe53a2312c25ebeb5157e81998054d2345b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index ef8d7d4e40ef4de46857db4932e109780e52f044..201fd93110b604b1c6df849fcc5c88ff98a93bea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp index cb65cc7b68b359ec015af3adcd3a7866983d492b..5d489b207efdf62de64d99469acfdbe49a13859f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp index 5b1014ed827d08f275b5da4c20abe14f153b9950..e09480d57bc20e54697fcd3c5ae761a66fda678b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp index e6f6add8bf45e98f6e9607b48a924e00ef642323..34065c334d1bee6064718a8dbac6f0b6f51c5826 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp index 80b3d03da26c49136c8220f8a433d8798ad751e1..95d7777a79371c4d9dfa1ec6ff09b16dc55a3807 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 93b3df1e572604e70d059c6cb622f8143eb2a92a..bf24bc76b1b8e3c64473df1d6b985712c53f4a20 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index f10365d892455e3e36f5e05ec94b7a4f577c1fbc..023f987127cc61983fb11f53838cff9383f0f257 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index a7a9eb62cef20c05cee9ff76e5d1c0b21fb95ab9..ffb199e58fe3a0200197ef1d119feec348a01502 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 9fb45b00365896be88e6aff214922c51325a4d65..90e979d891a5c4d1d9ca18653311afeda9029e5b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index 18a78674e7ac4d79e6da3b035c8e9ea966618f16..8a81b77891fb2bb92175ef0d2e0e06f3fc38fc82 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index cef6070af8a72a6cdb0616574736e722094f3a74..e1983add0459de966453bd86fb3d25e2fc68af2d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 1be70d6ca414f785ec2582fd90bc5d3d2617e4e3..47a180e1276b9290cd4a7506ab1a3df2917b2e6e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 6b8455ffa93fcceaf9d0e5d5595197f8bafddcbe..b8e994e91acc5f73469c257c4dd8d99bec193d97 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index b9e28e3d7ef70b2e3aa785b048280fb4959ce312..a590413accba4e9e0dddd8df2952cd9d3fdb8b35 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp index 2b1a5a57bb8b22b9235fdd02fe3f9da659b13c0d..1d010d1b0792dfc62f0a1de235558ab8e4e71fd0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp index 301d3b55b59c0a6818b27570fe1c9a7b9ff44946..f108b753421598c376eefb8cfa52b0a4f1522920 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp index cd16f35ff233fb04c79d50888f54b806e79e611a..b0b4bc012df347cf3e487d96ecef2e03826f5e77 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp index 39166698473dbd05010f960e82366012ffaa1863..df3bd94fcafa5062b597584c3aadf2772c376f39 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp index 0a623034ef65a1cfd0d220b9ad2253ab0b298620..73b4e776668c969fb6ddcd38f74bd1a74fea48c8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp index 5ef8d08de90a65ae04f552fc87447e3b7887bc92..76137a1c3e1e2cb82c9e15ec9bc71e4abe76391e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp index c9557bae893e1affc0f2138aaa9cfca206e267d1..f0158d8f3d51fbb773988d02c14c329e9243ebf9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp index 463e0865c0ae6980c534005d1ace62750645a74c..125fbc21a2b9c141adc8d11b0727667ca062fada 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp index b71ff1b9972a97276a0fedf0106814a54ba19a0c..cc33692d7378178903fe2b0a9a43180571d70276 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp index 9060c9b1b084512f038ccd5c5b0c3573ac32f0a9..704787a080ee0af7ea74f9f75a290a69cf73adf0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp index 81cf01d6a9d70346e8ea2387220fb03258e6e580..d64c9ec5e15aa85c54d5c852755be63ca0597327 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = std::tuple< // clang-format off @@ -130,6 +141,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp index 4da85cc46eb4ded15bcd63dd7b0d89864637d2b0..e68bd8e7e4aa032dc56007e22aba0f90ca0c6af8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0) // outout: e[m, n] // input: a[k, m], b[k, n], d0[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp index ab83e4baabf80242bfbf569fb13b5d39a3686ad7..5aaa2e8fe509a18ac360859f0cdcf67535b75fd8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index a4cd3fadbe93f4231388e9efccbe28a2c69a6f1b..7a2a3dbaf3d82cbed71e91e2bebe8fb3f3dd4ea0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp index 207e76ffe5f18f9a4d64a0083c27829933e7a5fa..fa3360997886de27f75b89457d8e1ba9ae4471d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -114,6 +125,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp index 3f30937ff5b3b7235efc93dc6287a494ed783dee..d9ae7817914dbcb7dcae38254afef6f02e54a4da 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp index d91e6c63bea06d70513106cf7eac171b892498b9..27c6cbe8ab71398f9c3bc03bb977c71af397836d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp index 143321542353e4d84d7275f375d9fd49103bc8cb..ff5f0e94a1732e3b36bc9295c4f1ce9a24fa467f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp index 09acc7c0f75bcd76412755a590605833c2bc3b5e..dba625e0ada1d4f09cb26a5860849a707a35cb97 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp index 47b8d23424d1dd59099441d3da8d80c1b889daf2..28a452c1a1124a326337cf4c21183035294fda1d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp index efa030ec4952acabef2a5fbd879adc07eeb2ecbf..13366238d6d9008d7ed425485670df5229f71cee 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp index f2735020e65bdbeda2c31df3c451d16903a0e320..8a4889ee83eabcae6466bdd5dd67d268c9d54287 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp index 7d4aae928b31d5c086ea717ad4be492a577b8cb1..fc3cbcf9055892bfcf63dcf6d3d3c4515026ca71 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index e8747af4828a2492f6db8e23b575de4efb7ce90e..bfb95bce85cb8696e645ec6b5626abf84456916f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index ed54c3a9bf58346a8829f3d7f275c6668e00ed1d..d0352339cf38fecf9c3b3d4fd8a91729ed53495f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index da7eae637bfb16580b4ec98e479768c2649adad7..d5b298ab217b4269beca858f0d54e319bf6dbf4f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 34345095e0c2489f44437ea58f58b2891cc91ff1..80c8f018f416154d0b2ccd26cb9f8822233a43e4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp index 55461dfba783de45a2991836535b3b652cc49a44..74ec9e1f8d28a77b0f78181cd6c05ea93b140e16 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp index 405e69975ca9ffec97999bce38fd73d953e139ed..eb98b3e7e871a1c5f4129032031175443ddc4de0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 9af31b3a129d2d1bf1ff6f32d0ee8e252cbdfa29..5f4a90125ac827ccfa406b5bf81f11313ce54258 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp index 841b7a1d47ce48b8486f0af04ff80571e4601d1a..38e3897d6a552d8a2f797b9f3b4baaa98636a53d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 9f7f643beb6832594bdc393777e5e08663a1f94e..803c44c7f53352109c28510c75f59e98165d6d8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp index c8e9f35d240da45fb902a2d02cd73d2eab5167e1..9b9ef3db24cb28e0783b5003678beeab01dd22f1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 5f804d45a560ac3b7b519cb3ab1840214e392d39..1a0b6c9d1c62d1aeb378ca140fd09959d213533a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[m, k], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 60cb138f5654bb6f2b10ea5ba0a03191a49906ab..18b1c0e993131b62d0692769090facd14034ea8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[m, k], b[n, k] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -113,6 +123,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 59e2b2da8641b9fb094750077d760e9f611d9c98..f0e7b6ab43c51abeb4e79e9e30b104bf05f82a07 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index bb09bf8b8e82f6c2056881f37b9e65c73657f0e5..56815b9ac5fa4d89b649089c73fbfeddc985cc8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index 0a3b566de64505adb19f75ffe5e126cfed5da438..e66d46a26995ec521df8fe15a7fd01277ebaff0d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 2b17e47b1cce153654d74ca7497b5c04117071de..fb1dfac69cd525ee7b0caa3d3d7b732ba5042bd8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp index e178d3b0adeb563c36d6a8f1f3936d048df23f75..fed2cbbfb980ab27e8810b643874fc9b308b21b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp index 52be9fe709bd710f3e05f381e342deef5c7affae..44ac4c08cd97881f98f0c36a9bd5d42cd5078327 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 9b5ff404845d5ae00b5664ee44a405e979ce52d2..30a2bf36d1cef771764e2832fc8d8a91b7120588 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -26,7 +26,8 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< @@ -35,14 +36,22 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 7fc35c4198eb49bf2625f6cfa8eb80a001fdc086..4b2a2dbdc73aee32a427f91470610325a5bb3080 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp index f27b2199e0c9dcb1e91566cb0aa0a05edaef5b0e..9d15ccd362774c77a9613046995d44761b8ccb49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp index b9a1095570aaf33aa8f3f93710cbb6d3ae276430..4e9ad58742c552f437dd4b8174cb675984dfc681 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp index 44e5f597d0d714438c97a9ce7d591736f4fe7f48..330e5aff907af00920dcc4f3cb1f5f2a3b50e6fa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp index f3a9063f7fb847b9c303ea836f97ed4cf3173ab9..0db3a15d2a4816cca200820346ccf7c9a711fa31 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp index 05ba449246e8c6f68182eb707a3e9da0974e2078..ccbfaeaf4f3d0c1d6b3f25adf8765466195af971 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp index 7a610a747cc9cc7555bd18b17fd499d19cd7c439..e10de67f94bf4747915552c54ff47874b53eb721 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp index 90e074f052c6f935bcbf862250670289e0634617..8b47e82f6d3574fd04c5ed69d82dd4b6f2584779 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp index 74aebf1031d5b02db54fece32fa693e0c532287d..5aa50adb31ecf5c8c636d0bb34946b166c5919e6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp index 361ea8f4ee9ff8144fd2392754fdaa53f380e40b..333b40c71ce1c755f7d7862c9ddae8013392ea8a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp index 3145b716402e4e5cf8e542dc991bc7db0a655e74..506a93ae9ba07d7a53e770809f2c47399432e476 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp index cde93f902c95d1c28f85b7a7627089fe3e97f841..30084f16d1d4c86ee2778b0d3526c2e593d39259 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 3b2968d48efd3e0061c36df1febb290b64264f73..85ec0f55aa78db969e8731213094f786a281d8d1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,8 @@ add_instance_library(device_grouped_conv2d_bwd_data_instance device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..01cd23b2e3b1b857e1f2f629ba92f0a49efea3b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_bf16_instances{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 3d604d42cc3a9fcbf8e9750ab5e3b2064f0bc161..bb0fd36e44f16987b3a11f7d8699a8510566fd59 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,80 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdDataDefault = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; - -static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; - -using device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< - // clang-format off - // 1. Default - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // 2. Filter1x1Stride1Pad0 - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) { + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f16_instances{}); + // 2. Filter1x1Stride1Pad0 add_device_operation_instances( - instances, device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); + instances, + device_grouped_conv2d_bwd_data_xdl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b671017f51c8dc31b256a95f1214ed17fb03fc63 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f32_instances{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bfe6e47d19218519f243ade0e36d8ea0c8567be8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_instance.hpp @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// f16_f16_f32_f16 +template +using device_grouped_conv2d_bwd_data_xdl_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + +#ifdef CK_WORKAROUND_SWDEV_3318619 + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, +#endif + // clang-format on + >; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv2d_bwd_data_xdl_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + +#ifdef CK_WORKAROUND_SWDEV_3318619 + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> +#endif + // clang-format on + >; + +// f32_f32_f32_f32 +template +using device_grouped_conv2d_bwd_data_xdl_f32_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + +#ifdef CK_WORKAROUND_SWDEV_3318619 + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b1fa63ddb3ce300db54be0beeade361f86bd2cc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_bf16_instances{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c8fc57e1c778be8b25426b2a62f06767b609b45 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f16_instances{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fcdfe8d42bf05aaeee6909380f3413ac9ae9b16a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f32_instances{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_data_xdl_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index ede21f1f4f7124156540518bf4252d3526929606..c2c0fc553cb38d299a1fbbc9990a53817c7105cd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 99e556618c3bfc535807ae8595a33e0f6b6d5390..5be7443eca6ec8c7acd76beeb2c790b711810b89 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 15871a28c3a2a85c053fc3259f9de67b5694ede6..2828b432ff6772d802d850bbcc17ee58a3af909c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 5ef1b6866efce4ade692d86e50a3240fdd6ac0ae..a36e1b47caa6c35082d64523941134ecc9ffb320 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -3,11 +3,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp # NHWGC, GKYXC, NHWGK + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - #dl + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + #dl device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..85a7a5be257365f37234c3b723d47a86a253b24d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index fc18b3c7358c44b3bdbf00db99274f432beefea4..5eb881549cf37c7aba2dae0e5ac1aa3402a9c3a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,100 +1,54 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "device_grouped_conv2d_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using AccDataType = float; -using OutDataType = ck::half_t; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances{}); + device_grouped_conv2d_fwd_dl_f16_instances{}); - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 648b39637ebe7a49d42c5dd2620834b76a04cc9b..4157853c41767f8ef705c864482fa56724d16e19 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,104 +1,54 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "device_grouped_conv2d_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using InDataType = float; -using WeiDataType = float; -using AccDataType = float; -using OutDataType = float; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances = std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances) + F32, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances{}); + device_grouped_conv2d_fwd_dl_f32_instances{}); - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp deleted file mode 100644 index 1cb5d06999512687cc1ef0e9fee1444096bcf1c2..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using InDataType = int8_t; -using WeiDataType = int8_t; -using AccDataType = int32_t; -using OutDataType = int8_t; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances{}); - - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances{}); - - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3d3f9b17930e051638f7e017bf7b4250fee19eac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "device_grouped_conv2d_fwd_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_dl_f32_instances = std::tuple< + // clang-format off + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 29f3310314548b332d6fab41f7ddcebb215775eb..bc7d577b69b623901b71e4127ee67d55676947f2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 6a4a3d2a45e2a7c11cedab653bf4bed447537115..55bf6da9b9e1b876bce918ee8d611c69a053d820 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 1fec35fd9fffa59afe9c7bdc81e3dd7ada2dde8f..202cdd6b446b20277eccedcd89bbf2346b906241 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,109 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances{}); + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp deleted file mode 100644 index 59b012134058cbd64a3548ea3892cd477660d4d9..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..07bea1c03cafca9e7cd3069c1fb04c0e1d6cf9f4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "device_grouped_conv2d_fwd_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_grouped_conv2d_fwd_xdl_f16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_xdl_bf16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_xdl_f32_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82edf896db1c932c2d3b19d9361e1294801b4035 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 8aca730433060d045e9c2045e6a366b08e0fe5cf..4bd3236caf31392c6626b0df299d02082f55c6e1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NHWGC = ck::tensor_layout::convolution::NHWGC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using NHWGK = ck::tensor_layout::convolution::NHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances{}); + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f5bdb202349ad0a5c10bbde97caefda0eca61f2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index e48db4a5314528bb45f4444f93da4ba8872e86d8..7ae87eed5d2504efa3c24b8c669c36ba19b7bc79 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 1655850ec148568734cbdbace8eefe8cb6487e0c..ab07341b59b3c9a4bacfcc842cf8ab2e8b09372b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index aba46b7ebeb90afb8881fc7ad4d8a2e1281eb146..15045bedd87a1d76eb2804e7e95dfac4db3e12ff 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 90efc09ee75f692a0664f3869cf0a6b18a7c2ce1..cd209dbf9e11e0031848fac0148f8e7e848ed480 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -4,8 +4,8 @@ add_instance_library(device_grouped_conv3d_fwd_instance device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index b4ae8b6ce5f2ba3d8f3569449934d84fc631252c..c3b100ea69c8312dd72a724d8b9c4e242efa86df 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 061674bd829840def99f4dcf317886844d398fea..ca488e9dc429e7cbb943979ce79f9f2dda157c48 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index ed7e5476760983e8ca5a17475585ec5eb9e834f2..6087b1b187cd4af0d45402379dc391479ea936c5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp index bf5fa306013e485e6dd57faa011bee8217001148..fd8c47deb38ab31641622317c28c9d7303c85392 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 91% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 8c384937352a012a701a21a3bb572d6b6fb7e57c..d2cacf4539de6ed0f0e6b85397244f913bd1d823 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -26,7 +26,7 @@ template using S = ck::Sequence; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -43,7 +43,7 @@ static constexpr auto ConvFwd1x1S1P0 = static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances = +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances = std::tuple< // clang-format off // Default @@ -51,64 +51,64 @@ using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances = //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Stride1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances{}); + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 91% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 487cd22721a086bbeca32dd56639614903596905..6354f011129b2128e28dd8e64940d102c31e04b7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -26,7 +26,7 @@ template using S = ck::Sequence; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -43,7 +43,7 @@ static constexpr auto ConvFwd1x1S1P0 = static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances = +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances = std::tuple< // clang-format off // Default @@ -51,64 +51,64 @@ using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances = //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Stride1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances{}); + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 91% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index d497cd57edfd85eda86d9641f49fa219b45e72d5..381ac5ef087a5202a05388f19ccc7245832fdcdb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -25,7 +25,7 @@ template using S = ck::Sequence; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -42,7 +42,7 @@ static constexpr auto ConvFwd1x1S1P0 = static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances = +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances = std::tuple< // clang-format off // Default @@ -50,64 +50,64 @@ using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances = //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, // Filter1x1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, // Filter1x1Stride1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances{}); + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp similarity index 91% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp index 2e53fbbda5c5e9d4f2712b17fce2f57f7d140842..6bd53d869343a0fdae7bd890a60ab854153b5a58 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -23,7 +23,7 @@ template using S = ck::Sequence; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -40,71 +40,71 @@ static constexpr auto ConvFwd1x1S1P0 = static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances = std::tuple< +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances = std::tuple< // clang-format off // Default //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, // Filter1x1Stride1Pad0 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances{}); + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 82beb2ace282742a17a14881b0520677b1324c74..b973b70aacc86a5e5baadd2e5b72f17704754fa4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -3,4 +3,8 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index b550bb28716e14d6618dd06ed06802211c92c866..aa161e51c989b2128a9a1e1c5ad055d2678ae737 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index a3f9c7a9e73f6f6db24a6526bf5e8195842a33ac..c454deac1a3a045f51a853d1546a70500360cc15 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index a93cb7fc84e9908613cc9e180198f60800524ec9..c829e8863d2c1156c1754ee3ae73c47a5f84a07a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -64,6 +64,7 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = st //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 2ace1b24320b215fc5d45f21aec69ddcbc346cff..fb30e7a97317aa74d589d0987011b54dc67dc62b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8642562fa38f25a7e487d7b4194ee8f8f16beb91 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Currently AK1 must equal BK1 ! + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83b31b07cf45c068be8711838ff43ab2ad0101a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Currently AK1 must equal BK1 ! + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa6365cd98c039c36007b875d425e3cef37ba782 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[n, k] = e[m, n] +using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4460b360be563123fdf7e02c43b49008dfc43ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp index c2f5f00c7acb62da4a901bc965d237aeb46bcf3a..f4086b6eac4dd74903fbddb955d5bd7f41ee748a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp index 476d4ce1f87a0d2d5c109ba847273edbb31cd8e1..d68eb76144579b073f8119d431910696e6b4af9e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 1023fa48102b7078b320be22beb19881cb678cba..2dfb8caace5144fa29934af425b005c28e9d77c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 6b065c0f82ac87adb5acbbf21cd470392890bd6d..598a0b0e29c22effb655b09247bc84ef6c1dd268 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt index aa0cc1148051c8000b588cc72d0f15f38a412112..176fb2fbee77418b877b0b76d0a2086dcce2da92 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -1,4 +1,11 @@ add_instance_library(device_normalization_instance - device_normalization_f16_instance.cpp - device_normalization_f32_instance.cpp + device_layernorm2d_f16_instance.cpp + device_layernorm2d_f32_instance.cpp + device_layernorm4d_f16_instance.cpp + device_layernorm4d_f32_instance.cpp + device_groupnorm_f16_instance.cpp + device_groupnorm_f32_instance.cpp + device_groupnorm_swish_f16_instance.cpp + device_groupnorm_swish_f32_instance.cpp + device_groupnorm_swish_f16_f32_f32_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3820462cf85ff5392e6e43a3fa78208425242e1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_5_3_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d85817aad3194f7c225d7d45da601c26a1046196 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_5_3_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a81f776c0f36008f9acf72c6328e8d6cd59ee4d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Swish = ck::tensor_operation::element_wise::Swish; + +void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_normalization_f16_f32_f32_f16_generic_instance{}); + add_device_operation_instances(instances, + device_normalization_f16_f32_f32_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4bb8bda8145f5e642f7028f7d9271e611038809 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Swish = ck::tensor_operation::element_wise::Swish; + +void add_device_normalization_rank_5_3_swish_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bbb9bd0fe8b83d3ed01489e5ebabc4708bba8a71 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Swish = ck::tensor_operation::element_wise::Swish; + +void add_device_normalization_rank_5_3_swish_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f7e4aff1a2987de40568fce9fed880838302a39 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_2_1_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f0db3a03666c03522559c4322bfe172206d6804 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_2_1_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb9d72e61423666be1fe0f2d2ab3e85d86832556 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_4_3_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed555b840da42bf6693f2c93cdce5f8f7dceb246 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +void add_device_normalization_rank_4_3_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); + add_device_operation_instances(instances, device_normalization_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp deleted file mode 100644 index beeaa3aa22d0d80619eb06114a6aabb0fb09f404..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Pass = ck::tensor_operation::element_wise::PassThrough; - -template -// clang-format off -using device_normalization_f16_instances = - std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl - >; -// clang-format on - -void add_device_normalization_rank_2_1_f16_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_normalization_f16_instances{}); -} - -void add_device_normalization_rank_4_3_f16_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_normalization_f16_instances{}); -} - -void add_device_normalization_rank_5_3_f16_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_normalization_f16_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp deleted file mode 100644 index 4d236fb633272f6391e8b393147edcf9c3e84734..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F32 = float; - -using Pass = ck::tensor_operation::element_wise::PassThrough; - -template -using device_layernorm_f32_instances = std::tuple< - // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl - // clang-format on - >; - -void add_device_normalization_rank_2_1_f32_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_layernorm_f32_instances{}); -} - -void add_device_normalization_rank_4_3_f32_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_layernorm_f32_instances{}); -} - -void add_device_normalization_rank_5_3_f32_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_layernorm_f32_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0684962f9ef649d1f90a83d40fb2b76c085c549 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_normalization_f16_instances = + // clang-format off + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f16_f32_f32_f16_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +template +using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0d0f896c8d932384fbe8e3e84a7f1d0e6b2c388d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt @@ -0,0 +1,10 @@ +add_instance_library(device_pool_fwd_instance + device_avg_pool2d_fwd_nhwc_f16_instance.cpp + device_avg_pool2d_fwd_nhwc_f32_instance.cpp + device_avg_pool3d_fwd_ndhwc_f16_instance.cpp + device_avg_pool3d_fwd_ndhwc_f32_instance.cpp + device_max_pool2d_fwd_nhwc_f16_instance.cpp + device_max_pool2d_fwd_nhwc_f32_instance.cpp + device_max_pool3d_fwd_ndhwc_f16_instance.cpp + device_max_pool3d_fwd_ndhwc_f32_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..508ad3873b65e0cf7109adb2ffc1f7d7666149eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; + +void add_device_pool2d_fwd_nhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ada96a93a2dfe1c71e83d3f14fdf2bb4b4647da0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; + +void add_device_pool2d_fwd_nhwc_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62bcad992ae1b8d77bfea25f673191979e1aeab8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; + +void add_device_pool3d_fwd_ndhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47896be911571933136b04e726d7273eb635254c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; + +void add_device_pool3d_fwd_ndhwc_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35c8522d9f9aedb64c15338a3e8389102534262b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + +void add_device_pool2d_fwd_nhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +void add_device_pool2d_fwd_nhwc_index_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75b7629f246c70c99bc8f021bc299142609bbd3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + +void add_device_pool2d_fwd_nhwc_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +void add_device_pool2d_fwd_nhwc_index_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbfc4acfdfbed1eb1c22a410e4d54edc915c39fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + +void add_device_pool3d_fwd_ndhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +void add_device_pool3d_fwd_ndhwc_index_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..63b3e8df8ea3254ee7f5fc0cb4be2270a144a72f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + +void add_device_pool3d_fwd_ndhwc_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +void add_device_pool3d_fwd_ndhwc_index_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8aa707885b2b785974ecef5491cae9ac7ec391d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I32 = int32_t; +using F16 = ck::half_t; +using F32 = float; + +template +using device_pool2d_fwd_nhwc_instances = + // clang-format off + std::tuple < + DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C, + DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C, + DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C + // clang-format on + >; + +template +using device_pool3d_fwd_ndhwc_instances = + // clang-format off + std::tuple < + DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C, + DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C, + DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp index 7729e42638050507e4228aaa31dd12aafd49b95a..711314985a7cd2ba708138df597e37d1e874709f 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,12 +19,13 @@ using Empty_Tuple = ck::Tuple<>; template using S = ck::Sequence; -using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; using GK = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; +using TanH = ck::tensor_operation::element_wise::TanH; using GK_Tuple = ck::Tuple; using GK_GK_Tuple = ck::Tuple; @@ -32,17 +33,25 @@ using I32_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; +// perlayer using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +// bias + perlayer using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Mul_TanH_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; +// perchannel using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; +// bias + perchannel using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +using Add_Mul2_TanH_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; static constexpr ck::index_t NDimSpatial = 2; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp index ba2451101efecb6af81ee7fe10a8bb2f78b66028..39c4f82fefd4beb2a0dc3629f0757f7f01037495 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( std::vector{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); } + +void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + // dl + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp index ea1c953bb2b9f805ca3816275db597db88e10c47..92e73eb2ee45895c4277eb46f2a5ec94116fc1e2 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); } + +void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp index 3c4987f159b5388c3f0558a8028657658054cb65..bb7a570cda26fbe65e4228a1fa1d9b9fbd4d896c 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,7 +12,10 @@ namespace device { namespace instance { // clang-format off -template , S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector> + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, int8_t, int8_t, DsDatatype, int8_t, int32_t, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp index d45c1c1ee776719ecce16dca40a480797992f78f..1d8b58fd18a49b777158c266459925a8ef67079b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_dl_perchannel_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); } + +void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp index d598d3d38e7c196ddb27948226fa3899666221cc..50ccc69f458fc9a46bbe5e08510e3da1174a90cc 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); } + +void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp index 262ec06b7fbe3e8c2aa517053478065b2cbbcd48..caced6c950a7872714b76aadee100e3a524a03d8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,30 +12,33 @@ namespace device { namespace instance { // clang-format off -template using device_grouped_conv2d_xdl_int8_instances = std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector> + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp index b0f86b5c2e3b834e6bc88b79f3a8aba4b9f37d90..526fe73463cba7ac977b597c2279488aaa9afe6b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_xdl_perchannel_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f16_f16_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_f16_f16_rank3_reduce1_instances(instances); - add_device_softmax_f16_f16_rank3_reduce2_instances(instances); - add_device_softmax_f16_f16_rank3_reduce3_instances(instances); -} - -void add_device_softmax_f16_f16_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_f16_f16_rank4_reduce1_instances(instances); - add_device_softmax_f16_f16_rank4_reduce2_instances(instances); - add_device_softmax_f16_f16_rank4_reduce3_instances(instances); - add_device_softmax_f16_f16_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp index fa334b997c27af20f05c42335bde28f5f32f819a..36867d993f9163bd7bdb8e81eb0ed51920ad2458 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp index 1c9d37d8483b7f249a0b1c05551c70acdb46db03..373f33ad59716bcc2ddeb628a92dd0d46f7ed0aa 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp index 5fbdab5055edb0f124bbc6541fcc705a12b71d63..d26b92b4f4998919d046401fb5f813594def112c 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 3>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp index 7dd8640b187a2b793d5ad6931ab329ae024ca432..bbb735b6fe564e586cc269eb030d2dc38d4f7cd1 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp index b32fe6838f83440e98f524c59af9a383760dfe7d..92dbe6776039a9d8e3fd2d358abac54d33fe4537 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp index c05048ec567bd6471fca14c41d062eaf8c0c86a5..354cda85d757cb04d2b803ffec32e3a6c36f1856 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 3>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp index 6a235708bd426c7270f16246fa058eae45cdd056..edb5e42c103badf9fd65f482577db9ccc2ee08ab 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce4_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 4>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 4>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp deleted file mode 100644 index e5bec5e2639d234b3944f679e99e988f20b2c383..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_f32_f32_rank3_reduce1_instances(instances); - add_device_softmax_f32_f32_rank3_reduce2_instances(instances); - add_device_softmax_f32_f32_rank3_reduce3_instances(instances); -} - -void add_device_softmax_f32_f32_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_f32_f32_rank4_reduce1_instances(instances); - add_device_softmax_f32_f32_rank4_reduce2_instances(instances); - add_device_softmax_f32_f32_rank4_reduce3_instances(instances); - add_device_softmax_f32_f32_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp index 57d3f184a6635f4febe632054c9e39798ef0cc82..566be8fc22c4a5b8a5a6bb17ddebf012bf39d1d4 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp index fae3a4dd6662d8920b1b88c8489e9a6f00f28ed0..f9c76e3116cd412c4efaca27595c5ebedf205943 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp index b6fb70e8e2a6aa097f2a76865f690812dffaffdd..541e0d71a939c5964e560d76771986bdbc1cfc03 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 3>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp index 33c7b6f35f351f755a50f48f92d391fd47f40163..95a38df2834b1794384bfdab57087c051a8fc482 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp index c22aa574b1f984d5a4f8c3276c9fb11b74002680..a29b88891d45b7f967629d956b69bc25c371ef88 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp index 55f3d2bd207a0a160f00415592f0106f41be5cee..0da46ea1b47489e9eb9fff991c7c5730e4b5338c 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 3>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp index fb0bcf5ee8a2dd6fd5ed4f77c0d8f4be5c5fcacd..fa217dc3f5b341112645c95df44826f19d2105c4 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce4_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 4>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 4>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp deleted file mode 100644 index 608cfcf8380be29e24d5e4b3aadd7573285a5224..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_i8_i8_rank3_reduce1_instances(instances); - add_device_softmax_i8_i8_rank3_reduce2_instances(instances); - add_device_softmax_i8_i8_rank3_reduce3_instances(instances); -} - -void add_device_softmax_i8_i8_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_i8_i8_rank4_reduce1_instances(instances); - add_device_softmax_i8_i8_rank4_reduce2_instances(instances); - add_device_softmax_i8_i8_rank4_reduce3_instances(instances); - add_device_softmax_i8_i8_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp deleted file mode 100644 index 15552dbae5d501c506f4fd14b439a80664a95b66..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce1_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp deleted file mode 100644 index 67674028860b471888bc87a59e1b7ae751c64eec..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce2_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp deleted file mode 100644 index 4b33da93c2e1330485a55f8ede6654ab2a03fbb1..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce3_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp deleted file mode 100644 index fe3b823e889267f69001e15066bbef201afa9813..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce1_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp deleted file mode 100644 index 8ecdf87d9fec061094f83c3beba01497dcb8e5b8..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce2_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp deleted file mode 100644 index 3563135204085ff8e44d1378f9c4d9ffd99e7b68..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce3_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp deleted file mode 100644 index aa21a0bf8a863f433665e01e8229bd179948e1db..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce4_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp index c8712d20939d5f9f1b7b981645e77420c7a5c607..57cedd60199e1c3aa291d44c404d56e89f444e9a 100644 --- a/library/src/utility/convolution_parameter.cpp +++ b/library/src/utility/convolution_parameter.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host_utility/io.hpp" diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp index 90f943313b0961bb96d6394e855809c62d050559..11166783e8e542f565bfc1171ae40d879a005268 100644 --- a/library/src/utility/device_memory.cpp +++ b/library/src/utility/device_memory.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host_utility/hip_check_error.hpp" diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index e34fbc8f345b8ca6dd5b34988a1839c6c59e61bd..7211552641195d68e42ff606ddc1b2c7aab7ba84 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/profiler/README.md b/profiler/README.md index bfd6a3a53be6ad8c84bb35c97d09f6125145591f..3ee039cb3f5d13be0ebdfc5ea41c5180a60cdbd8 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -46,3 +46,98 @@ out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} .... Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ``` + +## Profile contraction kernels +```bash +#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) +#arg2: data type (0: fp32; 1: f64)\n" +#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) +#arg4: verification (0: no; 1: yes) +#arg5: initialization (0: no init; 1: integer value; 2: decimal value) +#arg6: print tensor value (0: no; 1: yes) +#arg7: time kernel (0: no, 1: yes) +#arg8 and arg9: alpha and beta +#arg10 to 15: M0, M1, N0, N1, K0, K1 +#arg16 to 31: Strides for A, B, D and E (skip for default) + +################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 +./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 +``` + +Result (MI100) +```bash +a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} +b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384} +d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} +e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} +.... +Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s +``` + +## Profile batched gemm multiple D kernels +```bash +#arg1: tensor operation (batched_gemm_multi_d=Batched GEMM multi D); +#arg2: data type (0: fp16; 1: int8) +#arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n]; +# 1: A[g, m, k] * B[g, n, k] = C[g, m, n]; +# 2: A[g, k, m] * B[g, k, n] = C[g, m, n]; +# 3: A[g, k, m] * B[g, n, k] = C[g, m, n]) +#arg4: verification (0: no; 1: yes) +#arg5: initialization (0: no init; 1: integer value; 2: decimal value) +#arg6: print tensor value (0: no; 1: yes) +#arg7: time kernel (0=n0, 1=yes) +#arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount + +################ op datatype layout verify init log time M N K StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount +./bin/ckProfiler batched_gemm_multi_d 0 1 0 0 0 1 4096 4096 4096 4096 4096 4096 16777216 16777216 16777216 16 +``` + +Result (Radeon RX 6800 XT) +```bash +arg.a_grid_desc_k0_m0_m1_k1_{2048, 4096, 2} +arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} +arg.e_grid_desc_m_n_{ 4096, 4096} +.... +Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s +## Profile grouped convolution backward data kernels +```bash +# arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) +# arg2: data type (0: Output fp32, Weight fp32, Input fp32 +# 1: Output fp16, Weight fp16, Input fp16 +# 2: Output bf16, Weight bf16, Input bf16 +# arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K] +# 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K]) +# arg4: verification (0: no, 1: yes) +# arg5: initialization (0: no init, 1: integer value, 2: decimal value) +# arg6: print tensor value (0: no; 1: yes) +# arg7: time kernel (0: no, 1: yes) +# Following arguments (depending on number of spatial dims): +# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) +# G, N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) + + ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx +./bin/ckProfiler grouped_conv_bwd_data 1 0 1 1 0 1 2 32 4 192 192 3 3 28 28 1 1 1 1 1 1 1 1 + + ``` + +Result (MI100, FP16, GNHWC_GKYXC_GNHWK) +``` +out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} +wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192} +in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} +.... +Best configuration parameters: +name: DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 256, 32, 8, 2, Default, 32, 32, 2, 4, 8, 4, 1, 1> +avg_time: 0.768321 +tflops: 86.6679 +GB/s: 127.947 +``` diff --git a/profiler/include/profiler/data_type_enum.hpp b/profiler/include/profiler/data_type_enum.hpp index afcd6fea224f3e06f29309349bcc6ad55caa7d4f..c046c7fabb30e831fe41b31f936febef0f5bf148 100644 --- a/profiler/include/profiler/data_type_enum.hpp +++ b/profiler/include/profiler/data_type_enum.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/data_type_enum_helper.hpp b/profiler/include/profiler/data_type_enum_helper.hpp deleted file mode 100644 index d9bd5e1a4008fc5e6c09815c45408d8b07376a4f..0000000000000000000000000000000000000000 --- a/profiler/include/profiler/data_type_enum_helper.hpp +++ /dev/null @@ -1,77 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma - -#include "ck/utility/data_type.hpp" -#include "profiler/data_type_enum.hpp" - -namespace ck { - -template -struct get_datatype_from_enum; - -template <> -struct get_datatype_from_enum -{ - using type = int8_t; -}; - -template <> -struct get_datatype_from_enum -{ - using type = int32_t; -}; - -template <> -struct get_datatype_from_enum -{ - using type = half_t; -}; - -template <> -struct get_datatype_from_enum -{ - using type = float; -}; - -template <> -struct get_datatype_from_enum -{ - using type = double; -}; - -template -struct get_datatype_enum_from_type; - -template <> -struct get_datatype_enum_from_type -{ - static constexpr DataTypeEnum value = DataTypeEnum::Int8; -}; - -template <> -struct get_datatype_enum_from_type -{ - static constexpr DataTypeEnum value = DataTypeEnum::Int32; -}; - -template <> -struct get_datatype_enum_from_type -{ - static constexpr DataTypeEnum value = DataTypeEnum::Half; -}; - -template <> -struct get_datatype_enum_from_type -{ - static constexpr DataTypeEnum value = DataTypeEnum::Float; -}; - -template <> -struct get_datatype_enum_from_type -{ - static constexpr DataTypeEnum value = DataTypeEnum::Double; -}; - -} // namespace ck diff --git a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp index b16254279ce433d06f71ca350b53841a294a3de6..22dab31100d469938d39de0db89bf5b2257383c8 100644 --- a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp index 799dccc0ff3a8ef253a6153dbdc64d8f7a5c26cb..5bee67c1ce9dc619f2b5208bf350a8fc01dceb06 100644 --- a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp index 1583c6db21e0959bd155946b2e81363091b4f43f..f3d2c5561756d43a4af44bbde12cf93c26f5602c 100644 --- a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_impl.hpp index c07d7c0555490559bbf0357a5d2657bcbe5261fd..936c22f5d89143198a1d4a01cb06bd5c66ca043e 100644 --- a/profiler/include/profiler/profile_batched_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,9 +8,11 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -27,7 +29,11 @@ template + typename CLayout, + typename AElementOp, + typename BElementOp, + typename CElementOp, + typename DeviceOp> bool profile_batched_gemm_impl(int do_verification, int init_method, bool do_log, @@ -88,10 +94,6 @@ bool profile_batched_gemm_impl(int do_verification, b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - const auto a_element_op = AElementOp{}; const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; @@ -124,16 +126,6 @@ bool profile_batched_gemm_impl(int do_verification, b_device_buf.ToDevice(b_g_k_n.mData.data()); c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); - using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; - // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); @@ -148,23 +140,62 @@ bool profile_batched_gemm_impl(int do_verification, // profile device op instances for(auto& op_ptr : op_ptrs) { - auto argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - BatchStrideA, - BatchStrideB, - BatchStrideC, - BatchCount, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); + std::unique_ptr argument_ptr; + // false branch for multi d dl kernel + if constexpr(std::is_same< + DeviceOp, + ck::tensor_operation::device::DeviceBatchedGemm>::value) + { + + argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + } + else + { + argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + {}, + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + BatchCount, + StrideA, + StrideB, + {}, + StrideC, + BatchStrideA, + BatchStrideB, + {}, + BatchStrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + } auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp index 45b7b77388b63c8abba43aabc50a9c0dd86eeede..901fa338d4d7aab26a70da548e2b000c98840d26 100644 --- a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp index f5ec235141a7238157c0d2fd34f1f9c849decfc4..15a21206c5945d700069cd7a8f8a9fdff7c4e9bf 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp index 91c28f25fc5953460a00d87961b97d1587aad1ce..f2fcb0b133861b87a92e3434c8c3e22b0ef1a51c 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_backward_impl.hpp b/profiler/include/profiler/profile_batchnorm_backward_impl.hpp index 79d8862081fd78b0706110219dc5af5016cd9122..3343b5e66e4b8666edf778abf4393dd75cbd0937 100644 --- a/profiler/include/profiler/profile_batchnorm_backward_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_backward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_forward_impl.hpp b/profiler/include/profiler/profile_batchnorm_forward_impl.hpp index 82fe75bf015beba4d56898f67aa2100b2d3c6ff4..2f9538b16c326a740b1e3bd0c3c150eaaadb9ebf 100644 --- a/profiler/include/profiler/profile_batchnorm_forward_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_forward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_infer_impl.hpp b/profiler/include/profiler/profile_batchnorm_infer_impl.hpp index ca653393452c8b4fb880fee5041a72a497c9566b..1b31a2aabf5f42a115d83270f91d87f0f74d8a4d 100644 --- a/profiler/include/profiler/profile_batchnorm_infer_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_contraction_impl.hpp b/profiler/include/profiler/profile_contraction_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..660cc3f9e5236fdba5ce3c6373189cbc1ed5fbc1 --- /dev/null +++ b/profiler/include/profiler/profile_contraction_impl.hpp @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace profiler { + +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +using Scale = ck::tensor_operation::element_wise::Scale; + +template +int profile_contraction_impl(ck::index_t do_verification, + ck::index_t init_method, + bool do_log, + bool time_kernel, + CDElementOp cde_element_op, + const std::vector& M, + const std::vector& N, + const std::vector& K, + const std::vector& StridesA, + const std::vector& StridesB, + const std::vector& StridesE, + const std::vector& StridesD) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](const std::vector& dims01, + const std::vector& dims23, + const std::vector& strides) { + std::vector dims_szt(dims01.begin(), dims01.end()); + dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); + std::vector strides_szt(strides.begin(), strides.end()); + + return HostTensorDescriptor(dims_szt, strides); + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StridesA)); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StridesB)); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StridesD)); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + + DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + e_device_buf.SetZero(); + d_device_buf.ToDevice(d_m_n.mData.data()); + + const std::vector a_ms_ks_lengths = {M[0], M[1], K[0], K[1]}; + const std::vector b_ns_ks_lengths = {N[0], N[1], K[0], K[1]}; + const std::vector e_ms_ns_lengths = {M[0], M[1], N[0], N[1]}; + const std::vector d_m_n_lengths = {M[0], M[1], N[0], N[1]}; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + + constexpr ck::index_t NumDim = 2; + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference op + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); + + auto ref_argument = + ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1) + { + if constexpr(is_same::value) + { + cde_element_op(e_m_n_host_result(m0, m1, n0, n1), + c_m_n_host_result(m0, m1, n0, n1), + d_m_n(m0, m1, n0, n1)); + } + else if constexpr(is_same::value) + { + cde_element_op(e_m_n_host_result(m0, m1, n0, n1), + c_m_n_host_result(m0, m1, n0, n1)); + } + else + { + static_assert("Unsupported CDElementOp in contraction profiler."); + } + } + } + } + } + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + std::unique_ptr argument_ptr; + if constexpr(is_same::value) + { + argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{d_device_buf.GetDeviceBuffer()}, + static_cast(e_device_buf.GetDeviceBuffer()), + a_ms_ks_lengths, + StridesA, + b_ns_ks_lengths, + StridesB, + std::array, 1>{d_m_n_lengths}, + std::array, 1>{StridesD}, + e_ms_ns_lengths, + StridesE, + a_element_op, + b_element_op, + cde_element_op); + } + else if constexpr(is_same::value) + { + argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + a_ms_ks_lengths, + StridesA, + b_ns_ks_lengths, + StridesB, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + StridesE, + a_element_op, + b_element_op, + cde_element_op); + } + else + { + static_assert("Unsupported CDElementOp in contraction profiler."); + } + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + auto nelems_m = M[0] * M[1]; + auto nelems_n = N[0] * N[1]; + auto nelems_k = K[0] * K[1]; + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + e_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * nelems_m * nelems_n * nelems_k; + + std::size_t num_btype = sizeof(DataType) * nelems_m * nelems_k + + sizeof(DataType) * nelems_k * nelems_n + + sizeof(DataType) * nelems_m * nelems_n; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + float threshold = + static_cast(nelems_k) * std::numeric_limits::epsilon(); + pass = pass & ck::utils::check_err(e_m_n_device_result, + e_m_n_host_result, + "Error: incorrect results!", + threshold, + threshold); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", e_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", e_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f64"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " CDELayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " CDELayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StridesA = " << StridesA + << " StridesB = " << StridesB << " StridesE = " << StridesE << " : " << best_avg_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_contraction_utils.hpp b/profiler/include/profiler/profile_contraction_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..076bbd4559f6ec1652559030fa81be416385e962 --- /dev/null +++ b/profiler/include/profiler/profile_contraction_utils.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +using Scale = ck::tensor_operation::element_wise::Scale; + +enum struct ContractionMatrixLayout +{ + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 +}; + +enum struct ContractionDataType +{ + F32_F32_F32_F32, // 0 + F64_F64_F64_F64, // 1 +}; + +inline void collect_index_params(char* argv[], + std::vector& params, + const ck::index_t from, + const ck::index_t num) +{ + for(ck::index_t p = from; p < from + num; p++) + params.push_back(std::stoi(argv[p])); +} + +// Defualt strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1} +// Defualt strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1} +inline void +assign_default_strides(Row, std::vector& strides, std::vector dims) +{ + strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}; +} + +inline void +assign_default_strides(Col, std::vector& strides, std::vector dims) +{ + strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]}; +} diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index 86d394daf90c79d6f9faf34254901dc8e2bd43e2..52152a90fe5eab2fd80ab156d9080c75b65d9eb0 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp index 1aebef8bb2b2b34a640797ab9e5985d5cb09fbea..436fbdbd759195add8fba952235392d31b9613a0 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp index 2bac144334eefc2d39ff63e17637b50d977921ec..808c1a1c901293816bf3fb153c8f8a7d2faaaf15 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp index 1f3ba8f00714449e51f180879f3b38d01221cefe..bc2eb257970d1b91ab7fbf1a592d18df0ef3d159 100644 --- a/profiler/include/profiler/profile_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp b/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp deleted file mode 100644 index 1e69ebc8bd173b25311b45305a24d71fcd7324a9..0000000000000000000000000000000000000000 --- a/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp +++ /dev/null @@ -1,486 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/conv_util.hpp" -#include "ck/library/host_tensor/device_memory.hpp" -#include "ck/library/host_tensor/host_tensor.hpp" -#include "ck/library/host_tensor/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" - -using F16 = ck::half_t; -using F32 = float; -using BF16 = ck::bhalf_t; -using INT8 = int8_t; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using DeviceConvBwdDataNoOpPtr = - DeviceConvBwdDataPtr; -void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( - std::vector&); -void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( - std::vector&); -void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( - std::vector&); -void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( - std::vector&); - -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( - std::vector&); - -void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( - std::vector&); -void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( - std::vector&); -void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( - std::vector&); -void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( - std::vector&); -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { -using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr; - -template -HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} -template -HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} -template -HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} -template -void get_device_conv_bwd_data_op_ptr( - InDataType, WeiDataType, OutDataType, std::vector&, int) -{ - std::cout << "can not find device conv bwd data" << std::endl; - exit(1); -} -template <> -void get_device_conv_bwd_data_op_ptr( - F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); - break; - default: break; - } -} -template <> -void get_device_conv_bwd_data_op_ptr( - F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); - break; - default: break; - } -} -template <> -void get_device_conv_bwd_data_op_ptr( - BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); - break; - default: break; - } -} -template <> -void get_device_conv_bwd_data_op_ptr( - INT8, INT8, INT8, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); - break; - default: break; - } -} - -template -static bool check_out(const Tensor& ref, const Tensor& result) -{ - float max_diff = 1e-6; - - for(std::size_t i = 0; i < ref.mData.size(); ++i) - { - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - return false; - } - } - return true; -} -template -void show_data_nhwc_layout(Tensor& nhwc) -{ - std::cout << "["; - for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) - { - std::cout << "["; - for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) - { - std::cout << "["; - for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) - { - std::cout << "["; - for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) - { - std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; - } - std::cout << "]"; - } - std::cout << "]"; - } - std::cout << "]"; - } - std::cout << "]"; -} - -template -bool profile_convnd_bwd_data_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::vector& input_spatial_lengths, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths, - const std::vector& conv_filter_strides, - const std::vector& conv_filter_dilations, - const std::vector& input_left_pads, - const std::vector& input_right_pads) -{ - using InElementOp = ck::tensor_operation::element_wise::PassThrough; - using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - - const auto in_element_op = InElementOp{}; - const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; - - std::vector input_dims{static_cast(N), static_cast(C)}; - input_dims.insert( - std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); - - std::vector filter_dims{static_cast(K), static_cast(C)}; - filter_dims.insert(std::end(filter_dims), - std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths)); - - std::vector output_dims{static_cast(N), static_cast(K)}; - output_dims.insert(std::end(output_dims), - std::begin(output_spatial_lengths), - std::end(output_spatial_lengths)); - - Tensor input_host_result( - get_input_host_tensor_descriptor(input_dims, NDimSpatial)); - Tensor input_device_result( - get_input_host_tensor_descriptor(input_dims, NDimSpatial)); - Tensor weights( - get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); - Tensor output( - get_output_host_ensor_descriptor(output_dims, NDimSpatial)); - - std::cout << "input: " << input_host_result.mDesc << std::endl; - std::cout << "weights: " << weights.mDesc << std::endl; - std::cout << "output: " << output.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - output.GenerateTensorValue(GeneratorTensor_1{1}); - weights.GenerateTensorValue(GeneratorTensor_1{1}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - out_device_buf.ToDevice(output.mData.data()); - wei_device_buf.ToDevice(weights.mData.data()); - - // reset input to zero - in_device_buf.SetZero(); - - if(do_verification) - { - auto RunReference = [&](auto& ref_conv) { - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(input_host_result, - weights, - output, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - ref_invoker.Run(ref_argument); - }; - - auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); - RunReference(ref_conv); - } - - // add device Conv instances - std::vector conv_ptrs; - get_device_conv_bwd_data_op_ptr( - InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); - - if(conv_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - std::string best_conv_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device Conv instances - bool success = true; - for(auto& conv_ptr : conv_ptrs) - { - auto argument_ptr = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op); - - auto invoker_ptr = conv_ptr->MakeInvokerPointer(); - - if(conv_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string conv_name = conv_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = - ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = - ck::utils::conv::get_btype( - N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - - if(tflops > best_tflops) - { - best_conv_name = conv_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - in_device_buf.FromDevice(input_device_result.mData.data()); - - if(!check_out(input_host_result, input_device_result)) - { - std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; - - success = false; - } - else - { - std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; - } - - success = ck::utils::check_err(input_host_result, input_device_result); - - if(do_log) - { - std::cout << "in : "; - show_data_nhwc_layout(output); - std::cout << std::endl; - - std::cout << "wei: "; - show_data_nhwc_layout(weights); - std::cout << std::endl; - - std::cout << "out_host : "; - show_data_nhwc_layout(input_host_result); - std::cout << std::endl; - - std::cout << "out_device: "; - show_data_nhwc_layout(input_device_result); - std::cout << std::endl; - } - } - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; - return success; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp b/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp deleted file mode 100644 index e37c887a96f9c843abd62ee83f026da06f887daf..0000000000000000000000000000000000000000 --- a/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp +++ /dev/null @@ -1,474 +0,0 @@ -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/conv_util.hpp" -#include "ck/library/host_tensor/device_memory.hpp" -#include "ck/library/host_tensor/host_tensor.hpp" -#include "ck/library/host_tensor/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" - -using F16 = ck::half_t; -using F32 = float; -using BF16 = ck::bhalf_t; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using DeviceConvndBwdWeightNoOpPtr = - DeviceConvBwdWeightPtr; - -void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( - std::vector&); -void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector&); -void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( - std::vector&); - -void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( - std::vector&); -void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector&); -void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( - std::vector&); - -void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( - std::vector&); -void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( - std::vector&); -void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( - std::vector&); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -using DeviceConvndBwdWeightNoOpPtr = - ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr; - -template -HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -template -HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -template -HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - case 2: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - case 1: { - return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -template -void get_device_conv_bwd_weight_op_ptr( - InDataType, WeiDataType, OutDataType, std::vector&, int) -{ - std::cout << "can not find device conv bwd weight" << std::endl; - exit(1); -} - -template <> -void get_device_conv_bwd_weight_op_ptr( - F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); - break; - default: break; - } -} - -template <> -void get_device_conv_bwd_weight_op_ptr( - F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); - break; - default: break; - } -} - -template <> -void get_device_conv_bwd_weight_op_ptr( - BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 1: - ck::tensor_operation::device::instance:: - add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); - break; - case 2: - ck::tensor_operation::device::instance:: - add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - break; - case 3: - ck::tensor_operation::device::instance:: - add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); - break; - default: break; - } -} - -template -void show_data_nhwc_layout(Tensor& nhwc) -{ - std::cout << "["; - for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) - { - std::cout << "["; - for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) - { - std::cout << "["; - for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) - { - std::cout << "["; - for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) - { - std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; - } - std::cout << "]"; - } - std::cout << "]"; - } - std::cout << "]"; - } - std::cout << "]"; -} - -template -bool profile_convnd_bwd_weight_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t split_k) -{ - using InElementOp = ck::tensor_operation::element_wise::PassThrough; - using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - - const auto in_element_op = InElementOp{}; - const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; - - std::vector input_dims{static_cast(N), static_cast(C)}; - input_dims.insert( - std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); - - std::vector filter_dims{static_cast(K), static_cast(C)}; - filter_dims.insert(std::end(filter_dims), - std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths)); - - std::vector output_dims{static_cast(N), static_cast(K)}; - output_dims.insert(std::end(output_dims), - std::begin(output_spatial_lengths), - std::end(output_spatial_lengths)); - - Tensor input(get_input_host_tensor_descriptor(input_dims, NDimSpatial)); - Tensor weights_host_result( - get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); - Tensor weights_device_result( - get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); - Tensor output( - get_output_host_ensor_descriptor(output_dims, NDimSpatial)); - - std::cout << "input: " << input.mDesc << std::endl; - std::cout << "weights: " << weights_host_result.mDesc << std::endl; - std::cout << "output: " << output.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - output.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - break; - default: - input.GenerateTensorValue(GeneratorTensor_1{1}); - output.GenerateTensorValue(GeneratorTensor_1{1}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(input.mData.data()); - out_device_buf.ToDevice(output.mData.data()); - - // reset input to zero - wei_device_buf.SetZero(); - - if(do_verification) - { - auto RunReference = [&](auto& ref_conv) { - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(input, - weights_host_result, - output, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - ref_invoker.Run(ref_argument); - }; - - auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight(); - RunReference(ref_conv); - } - - // add device Conv instances - std::vector conv_ptrs; - get_device_conv_bwd_weight_op_ptr( - InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); - - if(conv_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - std::string best_conv_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device Conv instances - bool success = true; - for(auto& conv_ptr : conv_ptrs) - { - // using atomic, so need to reset input, setzero is done in invoker - // if(split_k > 1) - //{ - // wei_device_buf.SetZero(); - //} - - auto argument_ptr = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op, - split_k); - - if(!conv_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::cout << "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem" - << std::endl; - continue; - } - - auto invoker_ptr = conv_ptr->MakeInvokerPointer(); - std::string conv_name = conv_ptr->GetTypeString(); - float ave_time = 0; - - if(std::is_same::value && split_k > 1) - { - // alloc work space - size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get()); - if(bwd_weight_workspace_size <= 0) - { - printf("wrong work space size\n"); - exit(1); - } - DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); - wei_work_space_device_buf.SetZero(); - conv_ptr->SetWorkSpacePointer(argument_ptr.get(), - wei_work_space_device_buf.GetDeviceBuffer()); - ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - } - else - { - ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - } - - std::size_t flop = - ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = ck::utils::conv::get_btype( - N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - - if(tflops > best_tflops) - { - best_conv_name = conv_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - wei_device_buf.FromDevice(weights_device_result.mData.data()); - - success = ck::utils::check_err(weights_host_result, weights_device_result); - - if(success == false) - { - std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; - } - else - { - std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; - } - - if(do_log) - { - std::cout << "in : "; - show_data_nhwc_layout(output); - std::cout << std::endl; - - std::cout << "wei: "; - show_data_nhwc_layout(weights_host_result); - std::cout << std::endl; - - std::cout << "out : "; - show_data_nhwc_layout(input); - std::cout << std::endl; - - std::cout << "wei_device: "; - show_data_nhwc_layout(weights_device_result); - std::cout << std::endl; - } - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; - return success; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index 7707e16b089ef3058dea988188f80e453fcd4e7d..1fd9c811095a823ce2746ddba63431ccdd77f20b 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp index 3cc2ea3b92624da3db43696c181248e55a5f0256..81b8d8ddbf551f23de68beabd8900b4251c97c2a 100644 --- a/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp index d53a6589e0f505ed3beb6a7792c9045c49ee0c68..6f6d881c1e4ef5214fb13d3b7523c2b4c979fd49 100644 --- a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp index 40093e774f0bcd770e3bc1434566d5d36c069f89..25871dfb2ec814fb5d0cda406a31ad8602b9dc16 100644 --- a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp index e1c90f0f5255de2386b252acba5684910edeebf6..4c3d0a045054841b3f88a9c85ba9eaa8fd2538b7 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index b4ec78cdf37659b6e838b10a8c235fdd6a0c9c64..c0ffea8a326c8bbc95a9b2d2e46da651089d807a 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_bilinear_impl.hpp b/profiler/include/profiler/profile_gemm_bilinear_impl.hpp index 31bae281c45b2e4e9bbdd8068fd3b084270c45ad..b540e938b5b5ab5798a6a18f86e452caf819c75e 100644 --- a/profiler/include/profiler/profile_gemm_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bilinear_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp index f9a544c044f4165fbee8cab9ba10a10616f27e22..3893f8cdc7cfbf957fe28cd2b019092b59c88029 100644 --- a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 9b164104b505d3d3912206b5fa1a6ee6aeb2ef48..eaab5dbcc2c8d3a6d5e1ea5d59c1c292f696d16d 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index 370121a3ccff382539ee34994a05b4b035db5512..ff801e8afd3731e62bd6edb545cce87bbddeae74 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index e5d5f8765e1b25e67490da7444f31ac38c7aca27..6ffa31678f837a670f750dff8c6288d39cafb20a 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification, { case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -94,7 +94,7 @@ bool profile_gemm_splitk_impl(int do_verification, a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c_device_buf.SetZero(); using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto out_element_op = OutElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto in_element_op = InElementOp{}; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + in_host.SetZero(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + out_element_op, + wei_element_op, + in_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + in_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(in_device.mData.data()); + + pass = pass & ck::utils::check_err(in_device, in_host); + + if(do_log) + { + LogRangeAsType(std::cout << "output : ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + // do GEMM + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + {}, + static_cast(in_device_buf.GetDeviceBuffer()), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 4f9aa98376c9e56058f04938f8171ac93c230e3e..dc6739773ff5a282f2eeaca06b7979f0b07e8482 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index b201a2ed331cbf4a0e3951d8d69027f018af09f6..9fadfe96992a8d0374e784d1a95c4ac7989562bd 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp index 87e6ae44c76e56919fb17ad232586ae6294f1b68..f05b13b7495dfefee263f51d5404dda3dbc598dc 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 2c28850e651360b8e6185635eec9362456136e5d..09a651d77c994e69635cb4891fa0cd55ae52791b 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" @@ -18,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -39,9 +41,9 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& Ks, const std::vector& StrideAs, const std::vector& StrideBs, - const std::vector& StrideCs) + const std::vector& StrideCs, + int kbatch = 1) { - bool pass = true; auto f_host_tensor_descriptor = @@ -79,11 +81,11 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_device_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - +#if DEBUG_LOG std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl; - +#endif // DEBUG_LOG std::size_t num_thread = 1; switch(init_method) { @@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification, a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - - c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0{}, num_thread); } using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification, std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); b_device_buf.emplace_back( std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); - c_device_buf.emplace_back(std::make_unique( sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); - c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); + c_device_buf[i]->SetZero(); gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); @@ -192,43 +191,71 @@ bool profile_grouped_gemm_impl(int do_verification, DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + std::string gemm_name = gemm_ptr->GetTypeString(); + + if(kbatch > 1) + { + using DeviceOpSplitK = + ck::tensor_operation::device::DeviceGroupedGemmSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + if(dynamic_cast(gemm_ptr.get()) != nullptr) + { + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch); + } + } if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - std::string gemm_name = gemm_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = 0, num_btype = 0; - for(std::size_t i = 0; i < gemm_descs.size(); i++) + if(time_kernel) { - flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; - num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + - sizeof(CDataType) * Ms[i] * Ns[i]; - } + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } - float tflops = static_cast(flop) / 1.E9 / ave_time; + float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << gemm_name << std::endl; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } } if(do_verification) { + bool instance_pass = true; for(std::size_t i = 0; i < gemm_descs.size(); i++) { c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + c_device_buf[i]->SetZero(); Tensor c_m_n_host_result( f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); @@ -253,7 +280,20 @@ bool profile_grouped_gemm_impl(int do_verification, c_element_op); ref_invoker.Run(ref_argument); - pass = pass && ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + if(std::is_same_v && kbatch > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_result, + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && + ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + } if(do_log) { @@ -268,16 +308,25 @@ bool profile_grouped_gemm_impl(int do_verification, << std::endl; } } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; } } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; } } - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + } return pass; } diff --git a/profiler/include/profiler/profile_groupnorm_impl.hpp b/profiler/include/profiler/profile_groupnorm_impl.hpp index 81fec5590a8463ed68e1cf9797800f99e3fbb3c3..ebefe3dad42a91549b70ea26e171c40b78561eeb 100644 --- a/profiler/include/profiler/profile_groupnorm_impl.hpp +++ b/profiler/include/profiler/profile_groupnorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -190,9 +190,9 @@ bool profile_groupnorm_impl(int do_verification, if(time_kernel) { - LogRange(std::cout << "length = ", length, ",") << ", "; - std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, " - << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; + LogRange(std::cout << "length = ", length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; } if(num_kernel == 0) diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index 7dd90d079775ffe2da37d4e63e8391e1ee3709d0..2d87c8c8fe98bb32c84d3ffb851e5bd9a4f51525 100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0c888db1f47a6511de4822c1b27a23b2a2d6e814 --- /dev/null +++ b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_pool2d_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, // NCHW + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads) +{ + constexpr index_t InOutRank = 4; + constexpr index_t WindowRank = 2; + + if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || + window_strides.size() != WindowRank || input_left_pads.size() != WindowRank || + input_right_pads.size() != WindowRank) + return false; + + std::vector out_length(InOutRank); + + int N = in_length[0]; + int C = in_length[1]; + + out_length[0] = N; + out_length[1] = C; + + // Calculate Ho, Wo + for(int i = 2; i < InOutRank; ++i) + { + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + out_length[i] = (in_length[i] + pad1 + pad2 - windows_size) / windows_stride + 1; + } + + int Hi = in_length[2]; + int Wi = in_length[3]; + int Ho = out_length[2]; + int Wo = out_length[3]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + using namespace ck::literals; + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + + switch(init_method) + { + case 0: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = ck::tensor_operation::host::ReferencePoolingFwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + window_strides, + input_left_pads, + input_right_pads, + {2, 3}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", in_length, ", ") << std::endl; + } + + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = in_n_c_hi_wi.mDesc.GetElementSize() * sizeof(InDataType) + + out_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_indices_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + bool pass = ck::utils::check_err(out_n_c_ho_wo_device.mData, + out_n_c_ho_wo_host.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device, + out_indices_n_c_ho_wo_host); + } + + if(do_log) + { + LogRangeAsType(std::cout << "in_n_c_hi_wi : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_ho_wo_host : ", out_n_c_ho_wo_host.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_ho_wo_device : ", out_n_c_ho_wo_device.mData, ",") + << std::endl; + + if constexpr(OutputIndex) + LogRangeAsType(std::cout << "out_indices_n_c_ho_wo_device : ", + out_indices_n_c_ho_wo_device.mData, + ",") + << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", in_length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41b57fd85332b6007e3c4d44ef11574ddf044498 --- /dev/null +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_pool3d_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, // NCDHW + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector input_left_pads, + std::vector input_right_pads) +{ + constexpr index_t InOutRank = 5; + constexpr index_t WindowRank = 3; + + if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || + window_strides.size() != WindowRank || input_left_pads.size() != WindowRank || + input_right_pads.size() != WindowRank) + return false; + + std::vector out_length(InOutRank); + + int N = in_length[0]; + int C = in_length[1]; + + out_length[0] = N; + out_length[1] = C; + + // Calculate Do, Ho, Wo + for(int i = 2; i < InOutRank; ++i) + { + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + out_length[i] = (in_length[i] + pad1 + pad2 - windows_size) / windows_stride + 1; + } + + int Di = in_length[2]; + int Hi = in_length[3]; + int Wi = in_length[4]; + int Do = out_length[2]; + int Ho = out_length[3]; + int Wo = out_length[4]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) { + using namespace ck::literals; + + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + }; + + Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + Tensor out_n_c_do_ho_wo_host(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor out_indices_n_c_do_ho_wo_host(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + + Tensor out_n_c_do_ho_wo_device(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor out_indices_n_c_do_ho_wo_device( + f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + + switch(init_method) + { + case 0: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); break; + case 1: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_di_hi_wi.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = ck::tensor_operation::host::ReferencePoolingFwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(in_n_c_di_hi_wi, + out_n_c_do_ho_wo_host, + out_indices_n_c_do_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}, + {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, + {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, + window_strides, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", in_length, ", ") << std::endl; + } + + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = in_n_c_di_hi_wi.mDesc.GetElementSize() * sizeof(InDataType) + + out_n_c_do_ho_wo_host.mDesc.GetElementSize() * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += + out_indices_n_c_do_ho_wo_host.mDesc.GetElementSize() * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); + + bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, + out_n_c_do_ho_wo_host.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device, + out_indices_n_c_do_ho_wo_host); + } + + if(do_log) + { + LogRangeAsType( + std::cout << "in_n_c_di_hi_wi : ", in_n_c_di_hi_wi.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_do_ho_wo_host : ", out_n_c_do_ho_wo_host.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_do_ho_wo_device : ", out_n_c_do_ho_wo_device.mData, ",") + << std::endl; + + if constexpr(OutputIndex) + LogRangeAsType(std::cout << "out_indices_n_c_do_ho_wo_device : ", + out_indices_n_c_do_ho_wo_device.mData, + ",") + << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", in_length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_reduce_impl.hpp b/profiler/include/profiler/profile_reduce_impl.hpp index e6182002999b8c9b0838dab70d387d92d59e6cdc..b54aa65aef727b048458e0b63f638bad143e4e28 100644 --- a/profiler/include/profiler/profile_reduce_impl.hpp +++ b/profiler/include/profiler/profile_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_softmax_impl.hpp b/profiler/include/profiler/profile_softmax_impl.hpp index 96816f53bbb7da01435569cb7aa6c8e2ea099586..daaf565149784729fc66c4c6ef048df47c0c8dae 100644 --- a/profiler/include/profiler/profile_softmax_impl.hpp +++ b/profiler/include/profiler/profile_softmax_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -40,7 +40,11 @@ template <> std::string type_to_string() { return "int8"; } template <> std::string type_to_string() { return "int32"; } // clang-format on -template +template bool profile_softmax_impl(int do_verification, int init_method, bool do_log, @@ -54,7 +58,13 @@ bool profile_softmax_impl(int do_verification, if(Rank != in_length.size()) { throw std::runtime_error("Input tensor rank is different from template argument Rank!"); - } + }; + + if(NumReduceDim != reduce_dims.size()) + { + throw std::runtime_error( + "Input reduce_dims rank is different from template argument NumReduceDim!"); + }; Tensor in = in_strides.empty() ? Tensor(in_length) : Tensor(in_length, in_strides); @@ -92,8 +102,13 @@ bool profile_softmax_impl(int do_verification, // add device softmax instances using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using DeviceOp = tensor_operation::device:: - DeviceSoftmax; + using DeviceOp = tensor_operation::device::DeviceSoftmax; // get device op instances const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -112,13 +127,6 @@ bool profile_softmax_impl(int do_verification, for(auto& inst_ptr : instances) { - // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3 - // problem to rank 4 kernel) other than invoking IsSupportedArgument()? - if(!(inst_ptr->GetNumReduceDim() == static_cast(reduce_dims.size()))) - { - continue; - } - auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths, in_tensor_strides, reduce_dims, diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c21fff7de90dede1ac836fda2bdc992db62eacff..3f19096f48db259239de51a5dfc82059b0aca73a 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -25,11 +25,17 @@ set(PROFILER_SOURCES profile_reduce.cpp profile_groupnorm.cpp profile_layernorm.cpp + profile_avg_pool2d_fwd.cpp + profile_max_pool3d_fwd.cpp profile_softmax.cpp profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp profile_grouped_gemm_fastgelu.cpp + profile_contraction_bilinear.cpp + profile_contraction_scale.cpp + profile_batched_gemm_multi_d.cpp + profile_grouped_conv_bwd_data.cpp ) set(PROFILER_EXECUTABLE ckProfiler) @@ -70,4 +76,9 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_avg_pool2d_fwd.cpp b/profiler/src/profile_avg_pool2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c67897c04421085b69cb03e3ce97faf259d4f80b --- /dev/null +++ b/profiler/src/profile_avg_pool2d_fwd.cpp @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct avgPoolFwdArgParser +{ + std::unordered_map> long_opts = { + {"length", {}}, {"wsize", {}}, {"wstride", {}}, {"pad1", {}}, {"pad2", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_avg_pool2d_fwd() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: input tensor length for NDHW(e.g, --length 2 32 30 30) \n" + << "--wsize: window size for YX (e.g, --wsize 2 2) \n" + << "--wstride: window stride for HW (e.g, --wstride 2 2) \n" + << "--pad1: left side of padding in HW (e.g, --pad1 1 1) \n" + << "--pad2: right side of padding in HW (e.g, --pad2 1 1) \n" + << "eg: ckProfiler avg_pool2d_fwd 0 1 2 0 1 0 --length 2 32 30 30 --wsize 2 2 " + "--wstride 2 2 --pad1 1 1 --pad2 1 1" + << std::endl; +} + +int profile_avg_pool2d_fwd(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = true; + int init_method = 0; + bool do_log = false; + bool time_kernel = true; + + std::vector in_length = {2, 32, 30, 30}; + std::vector wsize = {2, 2}; + std::vector wstride = {2, 2}; + std::vector pad1 = {1, 1}; + std::vector pad2 = {1, 1}; + + if(argc != 2 && argc != 25) + { + print_help_avg_pool2d_fwd(); + return 0; + } + else if(argc == 25) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + + // parse the long options + avgPoolFwdArgParser arg_parser; + arg_parser(argc, argv); + in_length = arg_parser.long_opts["length"]; + wsize = arg_parser.long_opts["wsize"]; + wstride = arg_parser.long_opts["wstride"]; + pad1 = arg_parser.long_opts["pad1"]; + pad2 = arg_parser.long_opts["pad2"]; + } + + using F16 = ck::half_t; + using F32 = float; + using I32 = int32_t; + constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; + + if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + } + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("avg_pool2d_fwd", "avg_pool2d fwd", profile_avg_pool2d_fwd); diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 907a373794f2d8a59a066b94a591c4890c5c4038..222532b7bbdeb36bfa9953fb9f174baef3b548ac 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -10,6 +10,8 @@ #include "profiler/profile_batched_gemm_impl.hpp" #include "profiler_operation_registry.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + enum struct GemmMatrixLayout { MK_KN_MN, // 0 @@ -78,55 +80,72 @@ int profile_batched_gemm(int argc, char* argv[]) using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; - auto profile = [&](auto a_type, - auto b_type, - auto c_type, - auto a_layout, - auto b_layout, - auto c_layout) { - using ADataType = decltype(a_type); - using BDataType = decltype(b_type); - using CDataType = decltype(c_type); - - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CLayout = decltype(c_layout); - - const int DefaultStrideA = ck::is_same_v ? K : M; - const int DefaultStrideB = ck::is_same_v ? N : K; - const int DefaultStrideC = ck::is_same_v ? N : M; - - const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; - const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; - const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; - - const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; - const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; - const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; - - const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; - const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; - const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; - - bool pass = ck::profiler:: - profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - BatchStrideA_, - BatchStrideB_, - BatchStrideC_, - StrideA_, - StrideB_, - StrideC_, - BatchCount); - - return pass ? 0 : 1; - }; + auto profile = + [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; + + bool pass = ck::profiler::profile_batched_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { diff --git a/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp b/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp index f440a3094eb16b4eef1c5dee46cb08f6ab25d933..3d29c4b84aa5abfafbb220c46205413adb534770 100644 --- a/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp +++ b/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batched_gemm_gemm.cpp b/profiler/src/profile_batched_gemm_gemm.cpp index 6015c93be35268c1c87c129d0cc7d2c9bdf0584f..9a99874d1c7041f692133a6ef2e792bdf0dd8105 100644 --- a/profiler/src/profile_batched_gemm_gemm.cpp +++ b/profiler/src/profile_batched_gemm_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batched_gemm_multi_d.cpp b/profiler/src/profile_batched_gemm_multi_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98b462d9503438818554a8327b2f09d12943f154 --- /dev/null +++ b/profiler/src/profile_batched_gemm_multi_d.cpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F16_F16_F16, // 0 + INT8_INT8_INT8, // 1 +}; + +#define OP_NAME "batched_gemm_multi_d" +#define OP_DESC "Batched GEMM multi D" + +int profile_batched_gemm_multi_d(int argc, char* argv[]) +{ + if(argc != 18) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp16; 1: int8)\n"); + printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n"); + printf(" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n"); + printf(" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n"); + printf(" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchStrideA = std::stoi(argv[14]); + const int BatchStrideB = std::stoi(argv[15]); + const int BatchStrideC = std::stoi(argv[16]); + + const int BatchCount = std::stoi(argv[17]); + + using F16 = ck::half_t; + using INT8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = + [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + using DsDataType = ck::Tuple<>; + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + using DsLayout = ck::Tuple<>; + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmMultiD; + + bool pass = ck::profiler::profile_batched_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_multi_d); diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp index 6b1dfc01427a1cd34eefc89584bb6f3ce8bf206f..9620d63cafbeedfb5398b553b5172418923bb8e0 100644 --- a/profiler/src/profile_batched_gemm_reduce.cpp +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batchnorm_bwd.cpp b/profiler/src/profile_batchnorm_bwd.cpp index 44ce7350ff06eaad3eff7d93f60c022d1a4c2b60..1738d53dbe7c6d2581e15df8bae58f7c48733999 100644 --- a/profiler/src/profile_batchnorm_bwd.cpp +++ b/profiler/src/profile_batchnorm_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batchnorm_fwd.cpp b/profiler/src/profile_batchnorm_fwd.cpp index 902a1fc423f98880b43e1551f3305854d1a0b9b8..2b3e4eea41b5775d961a44e7b16a4ea79c67ffc9 100644 --- a/profiler/src/profile_batchnorm_fwd.cpp +++ b/profiler/src/profile_batchnorm_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batchnorm_infer.cpp b/profiler/src/profile_batchnorm_infer.cpp index 92c16859c1aa59ef814fa87166d3d27c9b43a3a2..f1c19bc36e8aded070b3ba843d9fa7007f37a55e 100644 --- a/profiler/src/profile_batchnorm_infer.cpp +++ b/profiler/src/profile_batchnorm_infer.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ed184120477c50d0adcb97e997038ad39945233 --- /dev/null +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_contraction_impl.hpp" +#include "profiler/profile_contraction_utils.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "contraction_bilinear" +#define OP_DESC "CONTRACTION+Bilinear" + +static void print_helper_msg() +{ + std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " + << "value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8 and arg9: alpha and beta\n" + << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" + << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" + << std::endl; +} + +int profile_contraction_bilinear(int argc, char* argv[]) +{ + const bool default_strides = argc == 16; + + if(argc != 32 && argc != 16) + { + print_helper_msg(); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); + const float beta = std::stof(argv[9]); + + std::vector M; + std::vector N; + std::vector K; + const ck::index_t dims_arg_num = 10; + collect_index_params(argv, M, dims_arg_num, 2); + collect_index_params(argv, N, dims_arg_num + 2, 2); + collect_index_params(argv, K, dims_arg_num + 4, 2); + + std::vector StridesA; + std::vector StridesB; + std::vector StridesE; + std::vector StridesD; + if(!default_strides) + { + collect_index_params(argv, StridesA, dims_arg_num + 6, 4); + collect_index_params(argv, StridesB, dims_arg_num + 10, 4); + collect_index_params(argv, StridesE, dims_arg_num + 14, 4); + collect_index_params(argv, StridesD, dims_arg_num + 18, 4); + } + + using F32 = float; + using F64 = double; + + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); + + using DataType = decltype(type); + + if(default_strides) + { + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); + } + bool pass = ck::profiler::profile_contraction_impl, + Bilinear>(do_verification, + init_method, + do_log, + time_kernel, + Bilinear{alpha, beta}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; + }; + + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6784b916f608811baeb361c5ea303c3d2eb145ca --- /dev/null +++ b/profiler/src/profile_contraction_scale.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_contraction_impl.hpp" +#include "profiler/profile_contraction_utils.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "contraction_scale" +#define OP_DESC "CONTRACTION+Scale" + +static void print_helper_msg() +{ + std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" + << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " + "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " + << "value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8: alpha\n" + << "arg9 to 14: M0, M1, N0, N1, K0, K1\n" + << "arg15 to 30: Strides for A, B, D and E (skip for default)\n" + << std::endl; +} + +int profile_contraction_scale(int argc, char* argv[]) +{ + const bool default_strides = argc == 15; + + if(argc != 31 && argc != 15) + { + print_helper_msg(); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); + + std::vector M; + std::vector N; + std::vector K; + const ck::index_t dims_arg_num = 9; + collect_index_params(argv, M, dims_arg_num, 2); + collect_index_params(argv, N, dims_arg_num + 2, 2); + collect_index_params(argv, K, dims_arg_num + 4, 2); + + std::vector StridesA; + std::vector StridesB; + std::vector StridesE; + std::vector StridesD; + if(!default_strides) + { + collect_index_params(argv, StridesA, dims_arg_num + 6, 4); + collect_index_params(argv, StridesB, dims_arg_num + 10, 4); + collect_index_params(argv, StridesE, dims_arg_num + 14, 4); + collect_index_params(argv, StridesD, dims_arg_num + 18, 4); + } + + using F32 = float; + using F64 = double; + + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); + + using DataType = decltype(type); + + if(default_strides) + { + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); + } + + bool pass = ck::profiler:: + profile_contraction_impl, Scale>( + do_verification, + init_method, + do_log, + time_kernel, + Scale{alpha}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; + }; + + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); diff --git a/profiler/src/profile_conv_bwd_data.cpp b/profiler/src/profile_conv_bwd_data.cpp index 9241ead738ef8a8a67d8682d743f6c9bd0faf640..465abacc4a024f7a8ed8fc0ac5feb0c0e6b13fc0 100644 --- a/profiler/src/profile_conv_bwd_data.cpp +++ b/profiler/src/profile_conv_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_fwd.cpp b/profiler/src/profile_conv_fwd.cpp index b57ee7fd94261f4437d1b62f0b6450ea74522db9..701999d8a9fcd3718fc8d62e9b77f4f49159d474 100644 --- a/profiler/src/profile_conv_fwd.cpp +++ b/profiler/src/profile_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp index b44007cde4742efcec188f0b620e2cf83295f5b6..31055ec1d1530ff5fe9c89e89957fdd6e83f7efc 100644 --- a/profiler/src/profile_conv_fwd_bias_relu.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp index 408dd02f78dd6241b5f0a67ae7aac4420c891943..8c2439a0c74be78b2d6ffcfdcce033c9ec0c7df5 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 61bae6ae70ebcf965618a100c3f553de73182118..b3587ea98d8555f86344621f3c8d48967b98ef46 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp index c3c0fb7b67daf131ab53194bbf28a12aeedf62aa..8af3768a48aa9905700b55ab73e83777f20f8223 100644 --- a/profiler/src/profile_gemm_add_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_fastgelu.cpp b/profiler/src/profile_gemm_add_fastgelu.cpp index 380b25a614c58d42209135218cca331de3bebe33..a09bb8340d3268ed642b1cd15f318f9458867c99 100644 --- a/profiler/src/profile_gemm_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_multiply.cpp b/profiler/src/profile_gemm_add_multiply.cpp index 7d6fead402f22a92d3b2c12b8d612c22582504ee..560467c264f2bcc8e406ed02de742c06942e27e0 100644 --- a/profiler/src/profile_gemm_add_multiply.cpp +++ b/profiler/src/profile_gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_relu_add_layernorm.cpp b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp index 5cbc3d21f8a335aaaf0353308a2fc996e210b520..558d255ce110565e90f6eff5fae86ce6bf6447bb 100644 --- a/profiler/src/profile_gemm_add_relu_add_layernorm.cpp +++ b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_bias_add_reduce.cpp b/profiler/src/profile_gemm_bias_add_reduce.cpp index 6d86db08223a771ddc4373c014a6a2a42f92285b..76daffbc67cb95ac19d52a6791d90816a5e6656b 100644 --- a/profiler/src/profile_gemm_bias_add_reduce.cpp +++ b/profiler/src/profile_gemm_bias_add_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_bilinear.cpp b/profiler/src/profile_gemm_bilinear.cpp index 3480014ba6e6a2088c1ee1a221eb99eee6191241..a1a48616b42cac2f0babf59955a68cc49c323ac3 100644 --- a/profiler/src/profile_gemm_bilinear.cpp +++ b/profiler/src/profile_gemm_bilinear.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_fastgelu.cpp b/profiler/src/profile_gemm_fastgelu.cpp index 2a137224cb096f4bf2968f06be13d71d44614794..93573002ef1c367a00d42dc2ea0fe99fca98086a 100644 --- a/profiler/src/profile_gemm_fastgelu.cpp +++ b/profiler/src/profile_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp index 395bf0627e43617d04d30a903a1f3cbdecf4a557..48f6f5eb49da151136600c2d03b99d959d51565f 100644 --- a/profiler/src/profile_gemm_reduce.cpp +++ b/profiler/src/profile_gemm_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp index f636ce718c669feb23a60fb1853a1219fea3da55..cc2da73cb72704ffb2c1283976f6307b2e89908c 100644 --- a/profiler/src/profile_gemm_splitk.cpp +++ b/profiler/src/profile_gemm_splitk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..351bb72c8153ffaea6359b38d344fec5efe043bc --- /dev/null +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 +}; + +#define OP_NAME "grouped_conv_bwd_data" +#define OP_DESC "Grouped Convolution Backward Data" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Output fp32, Weight fp32, Input fp32\n" + << " 1: Output fp16, Weight fp16, Input fp16\n" + << " 2: Output bf16, Weight bf16, Input bf16\n" + << "arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K]\n" + << " 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +} // namespace + +int profile_grouped_conv_bwd_data(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto out_layout, + auto wei_layout, + auto in_layout, + auto wei_type, + auto out_type, + auto in_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using OutLayout = decltype(out_layout); + using WeiLayout = decltype(wei_layout); + using InLayout = decltype(in_layout); + + using OutDataType = decltype(out_type); + using WeiDataType = decltype(wei_type); + using InDataType = decltype(in_type); + + bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + // GNHWC_GKYXC_GNHWK + if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); + } + } + // NHWGC_GKYXC_NHWGK + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_bwd_data); diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index dfd8a099f5408f76fd51a9c69a988591edf5348f..7a062ed5197109474e6bdad07a33f8b88c5b12f8 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 9ff3c15af05a5f67988ef4842e00e2671289bb1b..d0b424cde64249aff7da0646a6bb11e83ccecf39 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 871b2edfd5f829e1fa88d3ebcc399adc183542df..d023db54de47cdb4464148985547a80ad441190b 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -52,20 +52,24 @@ std::vector argToIntArray(char* input) int profile_grouped_gemm(int argc, char* argv[]) { - if(!(argc == 14)) + if(argc < 14) { - printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n"); + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" + << " 1: A[m, k] * B[n, k] = C[m, n];\n" + << " 2: A[k, m] * B[k, n] = C[m, n];\n" + << " 3: A[k, m] * B[n, k] = C[m, n])\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0=n0, 1=yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg15: kbatch value (default 4)\n" + << std::endl; + exit(1); } @@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[]) const auto StrideAs = argToIntArray(argv[11]); const auto StrideBs = argToIntArray(argv[12]); const auto StrideCs = argToIntArray(argv[13]); + const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[]) Ks, StrideAs, StrideBs, - StrideCs); + StrideCs, + kbatch); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { @@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[]) Ks, StrideAs, StrideBs, - StrideCs); + StrideCs, + kbatch); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { @@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[]) Ks, StrideAs, StrideBs, - StrideCs); + StrideCs, + kbatch); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[]) Ks, StrideAs, StrideBs, - StrideCs); + StrideCs, + kbatch); } else { diff --git a/profiler/src/profile_grouped_gemm_fastgelu.cpp b/profiler/src/profile_grouped_gemm_fastgelu.cpp index 9b6142f01559542455b002f5254ca848ec7867a5..50ecf25caebdedb385d9ab523abe5f73f1afb8d5 100644 --- a/profiler/src/profile_grouped_gemm_fastgelu.cpp +++ b/profiler/src/profile_grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_groupnorm.cpp b/profiler/src/profile_groupnorm.cpp index 2741f52717a5b7e370be2ddf3420273fc78df0e5..d55784ff0ad7274fce9bbf66b26a508f5994a578 100644 --- a/profiler/src/profile_groupnorm.cpp +++ b/profiler/src/profile_groupnorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -64,7 +64,7 @@ int profile_groupnorm(int argc, char* argv[]) ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; bool do_verification = false; int init_method = 0; - bool do_log = 0; + bool do_log = false; bool time_kernel = 1; std::vector length = {64, 16, 16, 32, 40}; diff --git a/profiler/src/profile_layernorm.cpp b/profiler/src/profile_layernorm.cpp index e93fc2dbd2bf004389eed968e0518a2cdb1ea4f6..7bf210e67837fa8fe0fe8982c40a1793344823a0 100644 --- a/profiler/src/profile_layernorm.cpp +++ b/profiler/src/profile_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_max_pool3d_fwd.cpp b/profiler/src/profile_max_pool3d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf6db2cfc91e1b7d52af6cc7e9f9d43f4051599e --- /dev/null +++ b/profiler/src/profile_max_pool3d_fwd.cpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_pool3d_fwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct maxPoolFwdArgParser +{ + std::unordered_map> long_opts = { + {"length", {}}, {"wsize", {}}, {"wstride", {}}, {"pad1", {}}, {"pad2", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_max_pool3d_fwd() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "arg6: return index (0=no, 1=yes)\n" + << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n" + << "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n" + << "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n" + << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n" + << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n" + << "eg: ckProfiler max_pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 " + "--wstride 2 2 2 --pad1 1 1 1 --pad2 1 1 1" + << std::endl; +} + +int profile_max_pool3d_fwd(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = true; + int init_method = 0; + bool do_log = false; + bool time_kernel = true; + bool return_index = false; + + std::vector in_length = {2, 32, 30, 30, 30}; + std::vector wsize = {2, 2, 2}; + std::vector wstride = {2, 2, 2}; + std::vector pad1 = {1, 1, 1}; + std::vector pad2 = {1, 1, 1}; + + if(argc != 2 && argc != 30) + { + print_help_max_pool3d_fwd(); + return 0; + } + else if(argc == 30) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + return_index = std::stoi(argv[7]); + + // parse the long options + maxPoolFwdArgParser arg_parser; + arg_parser(argc, argv); + in_length = arg_parser.long_opts["length"]; + wsize = arg_parser.long_opts["wsize"]; + wstride = arg_parser.long_opts["wstride"]; + pad1 = arg_parser.long_opts["pad1"]; + pad2 = arg_parser.long_opts["pad2"]; + } + + using F16 = ck::half_t; + using F32 = float; + using I32 = int32_t; + constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + + if(data_type == ck::DataTypeEnum::Half) + { + if(return_index) + ck::profiler::profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + else + ck::profiler::profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + } + else if(data_type == ck::DataTypeEnum::Float) + { + if(return_index) + ck::profiler::profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + else + ck::profiler::profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + pad1, + pad2); + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("max_pool3d_fwd", "max_pool3d fwd", profile_max_pool3d_fwd); diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index 6925371858ee913dcb5d2b529c3e350a6eec5f4f..e4af5680a5d8a192edfc380d18fbe0a2c595cbbb 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_softmax.cpp b/profiler/src/profile_softmax.cpp index 78b64dda7d70ec0a6cf8487562a9782c261f58c7..dfe8d95c904e297fcd2f7b83a166fc61c8950233 100644 --- a/profiler/src/profile_softmax.cpp +++ b/profiler/src/profile_softmax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -92,27 +92,76 @@ int profile_softmax(int argc, char* argv[]) { if(data_type == SoftmaxDataType::F16_F16) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else if(data_type == SoftmaxDataType::F32_F32) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else { @@ -124,27 +173,97 @@ int profile_softmax(int argc, char* argv[]) { if(data_type == SoftmaxDataType::F16_F16) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 4) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else if(data_type == SoftmaxDataType::F32_F32) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 4) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else { diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 080117e390c4df1bc14195649382267bcccc18f7..0f528c008f1d17afd42139d572d4f1cf1963fadd 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 91ff291233066527609133dc155e67388119385a..276b7b38dcd28ee935c108819cd0102b2bec4862 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 375ff493104feee709b1130ce577229c302ed878..426f68d4432e19f378c543e1edb90364f97e54af 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -8,13 +8,12 @@ MY_PROJECT_SOURCE=$1 cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -mcumode \ --mno-wavefrontsize64 -Wno-gnu-line-marker -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \ +-save-temps=$PWD" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ --D GPU_TARGETS="gfx908;gfx90a" \ +-D GPU_TARGETS="gfx908;gfx90a;gfx940" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} -#-D AMDGPU_TARGETS=gfx90a;gfx908 diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 268b1ebf9b7fa1743ef06b8cf3073dc44c0a7bc1..787eabbf96bc1029d54ff592bee8eaca6e214b89 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -11,9 +11,8 @@ cmake -D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=OFF \ --D GPU_TARGETS="gfx908;gfx90a" \ +-D GPU_TARGETS="gfx908;gfx90a;gfx940" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} -#-D AMDGPU_TARGETS=gfx90a;gfx908 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6f43e523559539dc17cb80e07f9863db14238a7d..ad08e94702ac7f3757d6269469acc94f5a98d0da 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -56,6 +56,10 @@ add_subdirectory(normalization) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) +add_subdirectory(contraction) +add_subdirectory(pool_fwd) +add_subdirectory(batched_gemm_multi_d) +add_subdirectory(grouped_convnd_bwd_data) if(GPU_TARGETS MATCHES "gfx1100") add_subdirectory(wmma_op) endif() diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 0574f98e872dcb436b14df3f536f06a787d28c8d..12142c33e71cc3751d4ac79cf739d02000da8731 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,15 +1,22 @@ -add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) -target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) -target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) + target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) -add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) -target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) -target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) + target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) + target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) -add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) -target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) -target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) + target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) + target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) -add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) -target_link_libraries(test_batched_gemm_int8 PRIVATE utility) -target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) + target_link_libraries(test_batched_gemm_int8 PRIVATE utility) + target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm/batched_gemm_bf16.cpp b/test/batched_gemm/batched_gemm_bf16.cpp index 78be540627850426a075e2bde1516410ef070756..5d12a1e956d469d0a1be7301c993bfd981ebafc3 100644 --- a/test/batched_gemm/batched_gemm_bf16.cpp +++ b/test/batched_gemm/batched_gemm_bf16.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = ck::bhalf_t; using BDataType = ck::bhalf_t; @@ -12,6 +14,8 @@ using CDataType = ck::bhalf_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index 6cbbedf6774d2f90f85e5128c5bae8fbe912dc8e..a2b61d951a724bce18836c229ca7ef865336d849 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = ck::half_t; using BDataType = ck::half_t; @@ -12,6 +14,8 @@ using CDataType = ck::half_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_fp32.cpp b/test/batched_gemm/batched_gemm_fp32.cpp index c9e565e264b493d6f7f1b248fe6d307150a4c989..2b18d166e68f96e53670b4783038228b22b41cec 100644 --- a/test/batched_gemm/batched_gemm_fp32.cpp +++ b/test/batched_gemm/batched_gemm_fp32.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = float; using BDataType = float; @@ -12,6 +14,8 @@ using CDataType = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_int8.cpp b/test/batched_gemm/batched_gemm_int8.cpp index 4da941a5766bc6decc5adc938b4ce39da5d8a6db..f607eaa84b996a0c28394647021953c5da6017a0 100644 --- a/test/batched_gemm/batched_gemm_int8.cpp +++ b/test/batched_gemm/batched_gemm_int8.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = int8_t; using BDataType = int8_t; @@ -12,6 +14,8 @@ using CDataType = int8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 386809717f242c5a23dd3c49437b4ebf055243eb..858efd837980658f20e0929ad93bc97e92ba7909 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,5 +1,11 @@ -add_custom_target(test_batched_gemm_gemm) - -add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) -target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) -add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(test_batched_gemm_gemm) + add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp index aa113de219437e900b2de8739ddaffb2caf3cefa..1a8d5c2e55906fb77cfc3fef7d6c985c37eefefb 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_gemm_util.hpp" diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp index 53c4d37c44781be976f917362d0c433d1ec1711f..b0fffc466efd7de5f9a8055cb3f7e372bfad2ef7 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_multi_d/CMakeLists.txt b/test/batched_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..45a306551f28295a33ac0eab54cdcd04052942cc --- /dev/null +++ b/test/batched_gemm_multi_d/CMakeLists.txt @@ -0,0 +1,5 @@ +# TODO: Enable for gfx90a after complier fix +if(NOT GPU_TARGETS MATCHES "gfx90a") + add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp) + target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) +endif() diff --git a/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp b/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a8265403434543ec400c235c84505bc873f329c --- /dev/null +++ b/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" + +namespace { +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestBatchedGemmMultiD : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + + static constexpr int M = 512; + static constexpr int N = 256; + static constexpr int K = 128; + static constexpr int BatchCount = 3; + + template + void Run() + { + using namespace ck::tensor_operation::device; + + const bool pass = + ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + EXPECT_TRUE(pass); + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; +} // namespace + +TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); + +TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run(); } + +TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run(); } diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 4dc0b0825744294d219a39d46ec4f8aec06fdf93..0710f4647a7a8fe1b346ed7fedfa0bd3e3d41354 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,10 @@ -add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) -target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) -target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) + set(target 1) + endif() +endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp index b150ce50d166d5a673c6b34c468b7988f3b4d91f..dd2638ce899bee16accfcfb85c402909b2c2a05c 100644 --- a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index 1ceecefb5f2c0a86e6d951948c84a2286023773f..984cc3c160d04ee40cdb227dfe339b5959a539f9 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,5 +1,11 @@ -add_custom_target(test_batched_gemm_softmax_gemm) - -add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) -target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) -add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm) + add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp index 5df7769d5f675e906e350738751923d26b6af39d..cb46a995c61c4dc15c9fcbd9c2ca07b850f90069 100644 --- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_util.hpp" diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp index 98debe19c3c0b419f46dd08321f43f846536772a..d8ee744c6098843e9e78e21ac59a5f3733f40aa8 100644 --- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 79af2b0d3af17c5e02cc0e0029cc6e1e59668183..4d04be3faa5af1adb2ae37793022ab794a27d9c3 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,15 +1,22 @@ -add_custom_target(test_batched_gemm_softmax_gemm_permute) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm_permute) -add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) -add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) -target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) -target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) -add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) -add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) -add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) -add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) -target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) -target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) -add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) -add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) \ No newline at end of file + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp index fe65a6fb9681a2a336132cf082cd9424e988e1ed..ef88ce6d81845d9ddb61fc00159899f0d11daa5f 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_bias_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp index 7235cd1b0b669cb1cfad6753289a94ad6a569df3..b38b10d1953091bc7d0739a38fa184d3d51a94f1 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp index af5f0efec3853eae68dfbb91f97930a1e8e5727b..d7c39367c8a19290718b315251581f4f3f626add 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp index defe36124056d602f565a20333379fc86a61b6aa..8e0baede113943c0bcb7d6f929217b2be94f0073 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp index 293acd60155ace98eaea89a11eb03fa6710fd877..81d404109fddc4321d870d1c3b8c0a2fe309c9c4 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp index 912bbc91edd51d06c827fc31495d2955bea85a2b..9df03ffd2a11211600872957e9686e105bf8b2f2 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batchnorm/batchnorm_bwd_rank_4.cpp b/test/batchnorm/batchnorm_bwd_rank_4.cpp index caa7331ea2c3e6dfe57a41272c5f4fba56662e39..a4696cf2a39a774c86d4a68f1a0edf0ce31a3f58 100644 --- a/test/batchnorm/batchnorm_bwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_bwd_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/batchnorm/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp index 13aef7d6bfc9f222e663b951a2f9d44e81fc93bc..9b6fbd0f662648426aad3f9c316604510ff6e1b1 100644 --- a/test/batchnorm/batchnorm_fwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/batchnorm/batchnorm_infer_rank_4.cpp b/test/batchnorm/batchnorm_infer_rank_4.cpp index 77fc1daae6130e415e88fc6f69966a82280acd3a..ecb4043b361b058ea82b91d98232eadb694ead33 100644 --- a/test/batchnorm/batchnorm_infer_rank_4.cpp +++ b/test/batchnorm/batchnorm_infer_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp index 55d9b59f489203cf68b5aebc1c5d9642801435c0..b8e349eda16f53d97b9655840f9803ef8e6a7ace 100644 --- a/test/block_to_ctile_map/test_block_to_ctile_map.cpp +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f6e0ed34187f1cd66d7bc8a2e8f1d5a16982d42 --- /dev/null +++ b/test/contraction/CMakeLists.txt @@ -0,0 +1,11 @@ +add_gtest_executable(test_contraction test_contraction.cpp) +target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + set(target 1) + endif() +endforeach() diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c86b849235f8c0b9eacc8ec9e3cf22e965ce60cc --- /dev/null +++ b/test/contraction/test_contraction.cpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "profiler/profile_contraction_impl.hpp" + +using F32 = float; +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +using Scale = ck::tensor_operation::element_wise::Scale; + +struct MemoryParams +{ + std::vector M; + std::vector N; + std::vector K; + std::vector StridesA; + std::vector StridesB; + std::vector StridesC; + std::vector StridesD; +}; + +template +class TestContraction : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CDLayout = std::tuple_element_t<2, Tuple>; + using DataType = std::tuple_element_t<3, Tuple>; + using DTupleDataType = std::tuple_element_t<4, Tuple>; + using CDElementOp = std::tuple_element_t<5, Tuple>; + + std::vector list_of_memory_params = {{{32, 32}, + {32, 32}, + {32, 32}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}}, + {{16, 16}, + {32, 32}, + {16, 16}, + {4096, 256, 16, 1}, + {16, 1, 8192, 256}, + {16384, 1024, 32, 1}, + {16384, 1024, 32, 1}}}; + + std::vector init_methods = {0, 1, 2}; + std::unique_ptr p_cd_element_op; + void Run() + { + for(auto& memory_params : list_of_memory_params) + { + for(const ck::index_t init_method : init_methods) + { + bool pass = + ck::profiler::profile_contraction_impl(true /*do_verification*/, + init_method, + false /*do_logs*/, + false /*time_kernel*/, + *p_cd_element_op, + memory_params.M, + memory_params.N, + memory_params.K, + memory_params.StridesA, + memory_params.StridesB, + memory_params.StridesC, + memory_params.StridesD); + EXPECT_TRUE(pass); + } + } + } +}; + +template +class TestContractionScale : public TestContraction +{ +}; + +template +class TestContractionBilinear : public TestContraction +{ +}; + +using BilinearKernelTypes = + ::testing::Types, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>>; + +using ScaleKernelTypes = ::testing::Types, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>>; + +TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); +TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); + +TYPED_TEST(TestContractionBilinear, bilinear) +{ + this->p_cd_element_op = std::make_unique(1.f, 1.f); + this->Run(); + this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); + this->Run(); +} + +TYPED_TEST(TestContractionScale, scale) +{ + this->p_cd_element_op = std::make_unique(1.f); + this->Run(); + this->p_cd_element_op = std::make_unique(0.5f); + this->Run(); +} diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9e720c597b1ca59e8924b250b3c021f51a17a80 --- /dev/null +++ b/test/contraction/test_contraction_interface.cpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "gtest/gtest.h" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" + +#include "ck/library/utility/device_memory.hpp" + +using Pass = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +template +using S = ck::Sequence; + +using F32 = float; +using F64 = double; + +template +class ContractionInstanceWrapper +{ + public: + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr ck::index_t NumDim = 2; + // clang-format off + using ContractionDeviceInstance = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; + // clang-format on + + bool isSupported(std::vector& ADims, + std::vector& BDims, + std::vector& DDims, + std::vector& EDims, + std::vector& AStrides, + std::vector& BStrides, + std::vector& DStrides, + std::vector& EStrides) const + { + auto contraction = ContractionDeviceInstance{}; + + auto argument = contraction.MakeArgument(nullptr, + nullptr, + std::array{nullptr}, + nullptr, + ADims, + AStrides, + BDims, + BStrides, + std::array, 1>{DDims}, + std::array, 1>{DStrides}, + EDims, + EStrides, + Pass{}, + Pass{}, + Bilinear{1.f, 1.f}); + return contraction.IsSupportedArgument(argument); + } +}; + +template +class ContractionDeviceOpWrapper +{ + + protected: + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD, + DataTypeD, + Pass, + Pass, + Bilinear>; + + public: + bool IsSupportedInstance(std::vector& Dims, + std::vector& Strides) const + { + + bool supported = false; + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(nullptr, + nullptr, + std::array{nullptr}, + nullptr, + Dims, + Strides, + Dims, + Strides, + std::array, 1>{Dims}, + std::array, 1>{Strides}, + Dims, + Strides, + Pass{}, + Pass{}, + Bilinear{1.f, 1.f}); + + supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get()); + } + return supported; + } +}; + +TEST(TestContractionInterface, IncorrectNumDims) +{ + std::vector> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}}; + std::vector> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; + ContractionDeviceOpWrapper wrapper_1d; + ContractionDeviceOpWrapper wrapper_2d; + ContractionDeviceOpWrapper wrapper_3d; + EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0])); + EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1])); + EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2])); +} + +TEST(TestContractionInterface, IncorrectDataTypes) +{ + std::vector Dims = {4, 4, 4, 4}; + std::vector Strides = {64, 16, 4, 1}; + ContractionDeviceOpWrapper wrapper_1; + ContractionDeviceOpWrapper wrapper_2; + EXPECT_FALSE(wrapper_1.IsSupportedInstance(Dims, Strides)); + EXPECT_FALSE(wrapper_2.IsSupportedInstance(Dims, Strides)); +} + +TEST(TestContractionSupportedArgs, ABMemoryAccess) +{ + std::vector Dims = {4, 4, 4, 4}; + std::vector Strides = {64, 16, 4, 1}; + std::vector StridesM1 = {4, 1, 64, 16}; + std::vector StridesK1 = {64, 16, 4, 1}; + std::vector InvalidStrides = {4, 4, 4, 4}; + // Memory access to A + ContractionInstanceWrapper<1, 2, 4> wrapperA1; + ContractionInstanceWrapper<2, 2, 4> wrapperA2; + EXPECT_FALSE( + wrapperA1.isSupported(Dims, Dims, Dims, Dims, InvalidStrides, Strides, Strides, Strides)); + EXPECT_FALSE( + wrapperA2.isSupported(Dims, Dims, Dims, Dims, InvalidStrides, Strides, Strides, Strides)); + EXPECT_TRUE( + wrapperA1.isSupported(Dims, Dims, Dims, Dims, StridesM1, Strides, Strides, Strides)); + EXPECT_TRUE( + wrapperA2.isSupported(Dims, Dims, Dims, Dims, StridesK1, Strides, Strides, Strides)); + // Memory access to B + ContractionInstanceWrapper<2, 1, 4> wrapperB1; + ContractionInstanceWrapper<2, 2, 4> wrapperB2; + EXPECT_FALSE( + wrapperB1.isSupported(Dims, Dims, Dims, Dims, Strides, InvalidStrides, Strides, Strides)); + EXPECT_FALSE( + wrapperB2.isSupported(Dims, Dims, Dims, Dims, Strides, InvalidStrides, Strides, Strides)); + EXPECT_TRUE( + wrapperB1.isSupported(Dims, Dims, Dims, Dims, Strides, StridesM1, Strides, Strides)); + EXPECT_TRUE( + wrapperB2.isSupported(Dims, Dims, Dims, Dims, Strides, StridesK1, Strides, Strides)); +} + +TEST(TestContractionSupportedArgs, DEMemoryAccess) +{ + std::vector Dims = {4, 4, 4, 4}; + std::vector Strides = {64, 16, 4, 1}; + std::vector InvalidStrides = {64, 16, 1, 4}; + ContractionInstanceWrapper<2, 2, 4> wrapper; + // Memory access to D + EXPECT_FALSE( + wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, InvalidStrides, Strides)); + EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides)); + // Memory access to E + EXPECT_FALSE( + wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, InvalidStrides)); + EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides)); +} diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 73797a7169e78cfeabac101ac0c7da366187edbf..6922bbbcc73dca4d0bc9e98ae49e8b504adb860a 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index 16ca4de8727878b25c75357e4ccff9bbc0517955..f734b46f53ef886e096b77d529f15644b3d7b96d 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,2 +1,9 @@ -add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) -target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) + target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp index 70231d42ae5f9eb9a368c643b2cf4041ae5fac4e..9d2b6cf5770144e25b21874e6f4da93c523e76bb 100644 --- a/test/convnd_bwd_data/convnd_bwd_data.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 97e170d85118b641a8338cde918bf60bda9a3803..745aceffc932389c8a8b2b35f384dcafa6f5691f 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,2 +1,9 @@ -add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) -target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) + target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) + set(target 1) + endif() +endforeach() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd.cpp index a1921a9bfbe0f1221ddada95c75835ab9c942292..fe8798ceb8b1086d7bea43008194845a47f6db43 100644 --- a/test/convnd_fwd/convnd_fwd.cpp +++ b/test/convnd_fwd/convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 088fbfec719f9e948686637fe9bec280f9926bd2..2b63727f19153dcc493bbb7ce4545b136a68b7db 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -2,3 +2,6 @@ if (USE_BITINT_EXTENSION_INT4) add_gtest_executable(test_int4 int4.cpp) target_link_libraries(test_int4 PRIVATE utility) endif() + +add_gtest_executable(test_fp8 fp8.cpp) +target_link_libraries(test_fp8 PRIVATE utility) diff --git a/test/data_type/fp8.cpp b/test/data_type/fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5004fe952775a74cc6b9736d1e34a8675d30187f --- /dev/null +++ b/test/data_type/fp8.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::f8_convert_sr; +using ck::f8_t; +using ck::half_t; +using ck::type_convert; + +TEST(FP8, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), 0x08); + EXPECT_EQ(ck::NumericLimits::Max(), 0x77); + EXPECT_EQ(ck::NumericLimits::Lowest(), 0xF7); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), 0x80); +} + +TEST(FP8, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(type_convert(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(type_convert(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(type_convert(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(type_convert(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, type_convert(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(f8_convert_sr(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(f8_convert_sr(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, f8_convert_sr(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP16Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(type_convert(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(type_convert(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(type_convert(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(type_convert(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, type_convert(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); +} + +TEST(FP8, ConvertFP16Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(f8_convert_sr(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); +} diff --git a/test/data_type/int4.cpp b/test/data_type/int4.cpp index 252a450bf96896e677cbc5f402ccfe722192d9bc..07549c1c48bb807c9e50520e173d528bf27b3bc9 100644 --- a/test/data_type/int4.cpp +++ b/test/data_type/int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp index e80995c4f08622432f4e6c3d3c5cf3200fa8c460..d5ce77dc2b956b02439a42eaa341285217780ac8 100644 --- a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp +++ b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_elementwise_layernorm_impl.hpp" diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 5290d466323277daa4e61af1a8273b816a790120..cde5c45aeab59453488ea89e3942550531b65b87 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp index 92e225def29623c7dcfc8590a5f805052dedb54d..cad250c6fb4d9eb935653d79999db3ff77a6932e 100644 --- a/test/gemm/gemm_fp16.cpp +++ b/test/gemm/gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 5d8c4881b621504fe621d056cbcf68bcb04622fc..c35aa77ea7dad4a7bf56e616d3456be932f44d8e 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp64.cpp b/test/gemm/gemm_fp64.cpp index 85d7f95bf4ad9ea8df19fdf2600c191f7eabcbf2..e67c8ba4f32331619fa5979c0e66226253993af1 100644 --- a/test/gemm/gemm_fp64.cpp +++ b/test/gemm/gemm_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index e73b22ce9c847f704fe55ad31cc83661b1ebab87..6ece05e306e66cb1013a846f3bdfda6ba4a195f4 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_standalone_xdl_fp16.cpp b/test/gemm/gemm_standalone_xdl_fp16.cpp index 32a243e0f694c0fa1e57178642d08885dd69d6a3..201a49dcd390090c056ec911f70fee9bd61b61af 100644 --- a/test/gemm/gemm_standalone_xdl_fp16.cpp +++ b/test/gemm/gemm_standalone_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_util.hpp" diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 9057c0af896891107683605575ee17aeb460c6cf..6c46f4ee89539b1e687b79f40803cfafa62528fc 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/test/gemm/instance/gemm_f16_nn_instance.cpp b/test/gemm/instance/gemm_f16_nn_instance.cpp index 4d65c5876cd92a04cf04a99d3f3655e50c586a8d..9016257f131ffa4a130c009256e66b1d5abe709a 100644 --- a/test/gemm/instance/gemm_f16_nn_instance.cpp +++ b/test/gemm/instance/gemm_f16_nn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nn_instance.hpp b/test/gemm/instance/gemm_f16_nn_instance.hpp index 5ae3928dc972f53eb1bb963cd2130e110ff19ae0..e174b99a1d8156540bdc3a0fcd16bb9870e96204 100644 --- a/test/gemm/instance/gemm_f16_nn_instance.hpp +++ b/test/gemm/instance/gemm_f16_nn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nt_instance.cpp b/test/gemm/instance/gemm_f16_nt_instance.cpp index 431ff1e62e79fc126deac89194996f2db3e2cb01..27103b88d4c25cd8658e8e58f2b0616377ccb99c 100644 --- a/test/gemm/instance/gemm_f16_nt_instance.cpp +++ b/test/gemm/instance/gemm_f16_nt_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nt_instance.hpp b/test/gemm/instance/gemm_f16_nt_instance.hpp index 99f9ffba4562cd03baff678302e0900a1b842e32..c624425e69573e8293bb9b667f35cfee070c2806 100644 --- a/test/gemm/instance/gemm_f16_nt_instance.hpp +++ b/test/gemm/instance/gemm_f16_nt_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tn_instance.cpp b/test/gemm/instance/gemm_f16_tn_instance.cpp index 6f5dbc311ebb81501eedc218b43003c777424be2..5b11f4dad9000672b1c2cac01cdb5f1eb4a3ad34 100644 --- a/test/gemm/instance/gemm_f16_tn_instance.cpp +++ b/test/gemm/instance/gemm_f16_tn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tn_instance.hpp b/test/gemm/instance/gemm_f16_tn_instance.hpp index 62388aeb398dff03c1f997c5ece4d53560492abd..563e10600adde840ac967f5e980362d53e6460e9 100644 --- a/test/gemm/instance/gemm_f16_tn_instance.hpp +++ b/test/gemm/instance/gemm_f16_tn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tt_instance.cpp b/test/gemm/instance/gemm_f16_tt_instance.cpp index b6ef5b1cd21a5314c83bedfcdef2d9c3d00906b9..9032150f0c5b7449244c23beea13df71d1e91ea5 100644 --- a/test/gemm/instance/gemm_f16_tt_instance.cpp +++ b/test/gemm/instance/gemm_f16_tt_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tt_instance.hpp b/test/gemm/instance/gemm_f16_tt_instance.hpp index 9d75b4e48cbba7905c3eebe2b30e8c663e31cbeb..62914d7ac22cb9a7ece1c321ef1a72368788ee2f 100644 --- a/test/gemm/instance/gemm_f16_tt_instance.hpp +++ b/test/gemm/instance/gemm_f16_tt_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp b/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp index 51c014a91a4e35ce16bd59feb16683899e6b4909..983af7ecddfef2c28c2d58d2f4e62bf36dd74309 100644 --- a/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp +++ b/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp b/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp index 110fc5f7d9a8e78507085c7bf870d8d733417076..ef269d78ee5e11c8813b2f4787583c34530d7b2d 100644 --- a/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp +++ b/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/run_gemm_test.inc b/test/gemm/run_gemm_test.inc index ec27729b3c9f02148cdd5537fe7f12e7421a7a4d..d208bb5a7b95a1775815dfa95f55befa5c401ea9 100644 --- a/test/gemm/run_gemm_test.inc +++ b/test/gemm/run_gemm_test.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run_gemm_test() { diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index c4feb5c564d59aa6d83a1d0b265f073184f96649..56b8d7737adf6db589cf15c4b615775843718029 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,7 +1,11 @@ -add_custom_target(test_gemm_layernorm) - -add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - -target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - -add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(test_gemm_layernorm) + add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) + set(target 1) + endif() +endforeach() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp index 740c63aa7ee5f7b6188872b11c14ab1e3bea9280..3f059968784cd908a2969221d52ff89a2db1400f 100644 --- a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp +++ b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp" diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp index 029165ece125abaa99dc1605bfbc2467b02ecc26..35a149f52c9f82634b4f63f837cbb6bbc808870e 100644 --- a/test/gemm_reduce/gemm_reduce_fp16.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 793091e53c8755db82eefe3cbf0af40666f71300..caf30fca595122a4d79935ea3bdbc10d83f364d1 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,3 +1,9 @@ -add_test_executable(test_gemm_split_k gemm_split_k.cpp) -target_link_libraries(test_gemm_split_k PRIVATE utility) -target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) + target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) + set(target 1) + endif() +endforeach() diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp deleted file mode 100644 index 1edb5769c6919817740e376d7c018781bb3930e2..0000000000000000000000000000000000000000 --- a/test/gemm_split_k/gemm_split_k.cpp +++ /dev/null @@ -1,261 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/library/utility/host_gemm.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 -}; - -template -static bool check_out(const Tensor& ref, const Tensor& result) -{ - float max_diff = 1e-6; - - for(std::size_t i = 0; i < ref.mData.size(); ++i) - { - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - return false; - } - } - - return true; -} - -struct gemmArgs -{ - GemmMatrixLayout layout; - int M; - int N; - int K; - int StrideA; - int StrideB; - int StrideC; - int KBatch; -}; - -int test_gemm(const gemmArgs& args) -{ - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - bool a_row_major, b_row_major, c_row_major; - - switch(args.layout) - { - case GemmMatrixLayout::MK_KN_MN: - a_row_major = true; - b_row_major = true; - c_row_major = true; - break; - case GemmMatrixLayout::MK_NK_MN: - a_row_major = true; - b_row_major = false; - c_row_major = true; - break; - case GemmMatrixLayout::KM_KN_MN: - a_row_major = false; - b_row_major = true; - c_row_major = true; - break; - case GemmMatrixLayout::KM_NK_MN: - a_row_major = false; - b_row_major = false; - c_row_major = true; - break; - default: printf("not supported layout"); return 1; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) { - using namespace ck::literals; - - if(row_major) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major)); - Tensor b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major)); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); - Tensor c_m_n_device_result( - f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); - - // init data - std::size_t num_thread = 1; - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - - host_gemm_mk_kn_mn(a_m_k, - b_k_n, - c_m_n_host_result, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); - - DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - - auto test = [&](auto a_layout, auto b_layout, auto c_layout) { - bool success = false; - - using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; - - const auto gemm_ptrs = - ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - args.M, - args.N, - args.K, - args.StrideA, - args.StrideB, - args.StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - args.KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get()); - - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(!check_out(c_m_n_host_result, c_m_n_device_result)) - { - success = false; - break; - } - success = true; - } - } - - return success; - }; - - bool success = false; - - if(args.layout == GemmMatrixLayout::MK_KN_MN) - { - success = test(Row{}, Row{}, Row{}); - } - else if(args.layout == GemmMatrixLayout::MK_NK_MN) - { - success = test(Row{}, Col{}, Row{}); - } - else if(args.layout == GemmMatrixLayout::KM_KN_MN) - { - success = test(Col{}, Row{}, Row{}); - } - else - { - success = test(Col{}, Col{}, Row{}); - } - - auto error_code = 0; - if(success) - { - std::cout << "test split k : Pass" << std::endl; - } - else - { - std::cout << "test split k: Fail " << std::endl; - error_code = -1; // test needs to report failure - } - return error_code; -} - -int main(int argc, char* argv[]) -{ - std::vector test_cases; - if(argc == 1) - { - test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2}, - {GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}}; - } - else if(argc == 9) - { - const auto layout = static_cast(std::stoi(argv[1])); - - const int M = std::stoi(argv[2]); - const int N = std::stoi(argv[3]); - const int K = std::stoi(argv[4]); - - const int StrideA = std::stoi(argv[5]); - const int StrideB = std::stoi(argv[6]); - const int StrideC = std::stoi(argv[7]); - const int KBatch = std::stoi(argv[8]); - test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}}; - } - else - { - printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n"); - return -1; - } - bool error = false; - for(const auto& kinder : test_cases) - { - error |= test_gemm(kinder); - } - return error ? 1 : 0; -} diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9eba5bba37551f39ffe921406ba871cff8dab4aa --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_splitk_util.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmSplitK_MK_KN + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_MK_NK + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_KM_KN + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_KM_NK + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ADataType, BDataType, CDataType + std::tuple< F16, F16, F16>, + std::tuple< F32, F32, F32> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes); + +#include "test_gemm_splitk_ut_cases.inc" diff --git a/test/gemm_split_k/test_gemm_splitk_ut_cases.inc b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc new file mode 100644 index 0000000000000000000000000000000000000000..54b9c6c9e314870446b75b9946cd2980ab008457 --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc @@ -0,0 +1,217 @@ +#pragma once + +TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} diff --git a/test/gemm_split_k/test_gemm_splitk_util.hpp b/test/gemm_split_k/test_gemm_splitk_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8243747a694141b373cd6d6896dd17213fd3192c --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk_util.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_splitk_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmSplitK : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1) + { + bool pass = ck::profiler::profile_gemm_splitk_impl( + verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d7c9f465548f5e6756ed043bf28fb322dc71753d --- /dev/null +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") + add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance) + add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94808669a321d1ebd766dd62f05d18ba8ccbf47b --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; +TYPED_TEST_SUITE(TestGroupedConvndBwdData, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdData, Test2D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc592ba665604f93bdcd3b545152376d14d55c88 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using DataType = ck::half_t; +using AccDataType = float; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvBackwardDataSpecialization = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + +static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using OutLayout = std::tuple_element_t<0, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + + // clang-format off + using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 + // ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + template + bool Run() + { + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto conv = GroupedConvBwdDataDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + std::array{}, + nullptr, + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + Pass{}, + Pass{}, + Pass{}); + return conv.IsSupportedArgument(argument); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = + ::testing::Types, std::tuple>; + +template +class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData +{ +}; + +template +class TestGroupedConvndBwdDataFilter1x1 + : public TestGroupedConvndBwdData +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index e2f0790c8b6b63381454ebd9db368e985714e659..8be872bc690cee34b70aa921e79a7b71e113ac1c 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,2 +1,9 @@ -add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp) -target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp index 6545b6e5662bb971a33e2b0d158a3440636b35fe..207cdab7c7eec32086898bd8f84e8c3d29fb4d79 100644 --- a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test DataType, DataType, DataType>(true, // do_verification - 1, // init_method integer value + 1, // init_method: integer value false, // do_log false, // time_kernel param, @@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes); TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) { this->conv_params.clear(); - this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); - this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->template Run<1>(); } @@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) { this->conv_params.clear(); this->conv_params.push_back( - {2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( - {2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->template Run<2>(); } @@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) { this->conv_params.clear(); this->conv_params.push_back( - {3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( - {3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->template Run<3>(); } diff --git a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp index 6df7f9969cb5d802ae28d0a7fdf03221c6c91c4f..4a804ef7f1af5abe4aa040867335043a726189d1 100644 --- a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 31a78733d38c14f6f328d1bdea3ea3cbd3341507..8c57b667e27287f51c79dc79b879506f5d4d0ba4 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,3 +1,14 @@ -add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp) -target_link_libraries(test_grouped_gemm_fp16 PRIVATE utility) -target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(test_grouped_gemm) + add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) + add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface) + set(target 1) + endif() +endforeach() diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp deleted file mode 100644 index f20d750d361040de1a4e5006d8ed25c256945174..0000000000000000000000000000000000000000 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include - -#include "profiler/profile_grouped_gemm_impl.hpp" - -namespace { - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -bool TestGroupedGemm() -{ - - std::mt19937 gen(19391); - std::uniform_int_distribution<> distrib(1, 10); - int group_count = distrib(gen); - - // GEMM shape - std::vector gemm_descs; - std::vector p_a, p_b; - std::vector p_c; - - std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideCs; - - for(int i = 0; i < group_count; i++) - { - Ms.push_back(256 + 256 * distrib(gen)); - Ns.push_back(256 + 256 * distrib(gen)); - Ks.push_back(128 + 128 * distrib(gen)); - - StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); - StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); - StrideCs.push_back(std::is_same::value ? Ns[i] : Ms[i]); - } - - return ck::profiler::profile_grouped_gemm_impl( - true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs); -} - -} // anonymous namespace - -int main() -{ - bool res = true; - - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - - std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - return res ? 0 : 1; -} diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffa8840fc7d8ba6b8a1c1a39c62edc6edae58064 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_interface.cpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "gtest/gtest.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "test_grouped_gemm_util.hpp" + +class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using ALayout = Row; + using BLayout = Col; + using ELayout = Row; + + static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + + template + using GGemmInstance = + ck::test::DeviceGroupedGemmSplitkInstanceWrapper; + + using DefaultGGemmInstance = GGemmInstance; +}; + +TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize) +{ + std::vector Ms{128, 256, 188, 512}; + constexpr int N = 256; + constexpr int K = 128; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % MPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{256, 128, 128, 512}; + Ns = std::vector{256, 177, 128, 512}; + // N % NPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth) +{ + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + using PaddedGGemmInstance = GGemmInstance; + + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 512; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // K % ABlockTransferSrcScalarPerVector + Ks = std::vector{256, 177, 128, 512}; + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ks = std::vector{256, 164, 128, 512}; + // K % BBlockTransferSrcScalarPerVector + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ks = std::vector(4, 128); + Ns = std::vector{256, 127, 128, 512}; + // N % CBlockTransferScalarPerVector_NWaveNPerXDL + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) +{ + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 128; + constexpr int kbatch = 4; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // kloops % 2 + Ks = std::vector{256, 512, 320, 768}; + EXPECT_FALSE( + DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); + + // Not all gemms have same value for main_k0_block_loop! + Ks = std::vector{256, 512, 512, 512}; + EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch), + std::runtime_error); +} + +class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using ALayout = Col; + using BLayout = Row; + using ELayout = Col; + + static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + + template + using GGemmInstance = + ck::test::DeviceGroupedGemmSplitkInstanceWrapper; + + using DefaultGGemmInstance = GGemmInstance; +}; + +TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize) +{ + std::vector Ms{128, 256, 188, 512}; + constexpr int N = 256; + constexpr int K = 128; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % MPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{128, 256, 256, 512}; + Ns = std::vector{256, 177, 128, 512}; + // N % NPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth) +{ + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + using PaddedGGemmInstance = GGemmInstance; + + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 512; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % ABlockTransferSrcScalarPerVector + Ms = std::vector{256, 177, 128, 512}; + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{128, 256, 256, 512}; + Ns = std::vector{256, 164, 128, 512}; + // N % BBlockTransferSrcScalarPerVector + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ns = std::vector{128, 256, 256, 512}; + Ms = std::vector{256, 130, 128, 512}; + // M % CBlockTransferScalarPerVector_NWaveNPerXDL + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9282fa924943c52d317f1fe89b45c90e54c3afb --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_splitk.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +using F16 = ck::half_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using RRR_F16_F16_F16 = ck::test::TestGroupedGemm>; +using RCR_F16_F16_F16 = ck::test::TestGroupedGemm>; + +using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm>; +using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm>; + +const std::vector KBATCH{1, 2, 3, 5, 8}; + +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN, + RRR_F16_F16_F16_LargeK, + testing::Values(32, 64)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK, + RCR_F16_F16_F16_LargeK, + testing::Values(32, 64)); + +#include "test_grouped_gemm_ut_cases.inc" diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc new file mode 100644 index 0000000000000000000000000000000000000000..d94d140d97ecfb1ce0d879af747fc2badaba6ecc --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -0,0 +1,180 @@ +#pragma once + +TEST_P(RRR_F16_F16_F16, TinyCases) +{ + const std::vector Ms{0, 1}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5, 0}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16, TinyCases) +{ + const std::vector Ms{0, 1}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5, 0}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16, Regular) +{ + const std::vector Ms{32, 64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) +{ + const std::vector Ms{188, 210}; + constexpr int N = 768; + constexpr int K = 4096; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) +{ + const std::vector Ms{188, 210}; + constexpr int N = 768; + constexpr int K = 4096; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b61118b5120e577defe9453c741875cafe33d168 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/number.hpp" +#include "profiler/profile_grouped_gemm_impl.hpp" + +namespace ck { +namespace test { + +template +std::string serialize_range(const Range& range) +{ + std::stringstream ss; + for(auto& r : range) + { + ss << r << ", "; + } + std::string str = ss.str(); + return std::string(str.begin(), str.end() - 2); +} + +template +class TestGroupedGemm : public testing::TestWithParam +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + void SetUp() override {} + + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) + { + bool pass = ck::profiler::profile_grouped_gemm_impl( + verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch); + EXPECT_TRUE(pass); + } +}; + +template +struct DeviceGroupedGemmSplitkInstanceWrapper +{ + using F16 = half_t; + using F32 = float; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + + using EmptyTuple = ck::Tuple<>; + + template + using S = ck::Sequence; + + template + using I = ck::Number; + + using ABlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcAccessOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; + using ABlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<8>, I<2>>; + using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; + + using BBlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcAccessOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; + using BBlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<2>, I<8>>; + using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; + + using DeviceGroupedGemmSplitKInstance = + tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< + ALayout, + BLayout, + EmptyTuple, + ELayout, + F16, + F16, + F32, + F16, + EmptyTuple, + F16, + PassThrough, + PassThrough, + PassThrough, + GemmSpec, + 1, + 128, + 128, + 128, + KPerBlock, + K1, + K1, + 32, + 32, + 4, + 2, + S<1, 4, 32, 1>, + ABlockTransferThreadClusterArrageOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim::value, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1::value, + ABlockLdsAddExtraM::value, + S<1, 4, 32, 1>, + BBlockTransferThreadClusterArrageOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim::value, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1::value, + BBlockLdsAddExtraM::value, + 1, + 1, + S<1, 16, 1, 8>, + CDEBlockTransferScalarPerVector_NPerBlock>; + + bool IsSupported(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(argument, kbatch); + } + + return ggemm_instance.IsSupportedArgument(argument); + } + + float Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(argument, kbatch); + } + + EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); + auto invoker = ggemm_instance.MakeInvoker(); + DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument)); + ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + return invoker.Run(argument, StreamConfig{nullptr, false}); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp index 680fddf1933611dc3088cb673ece0dba9e1f3911..253f21e91f80c96c081108b47b045c3902b36926 100644 --- a/test/magic_number_division/magic_number_division.cpp +++ b/test/magic_number_division/magic_number_division.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index 60d3b13959fc835e4d42244bfff5437561f3178a..325ea75fe5668897a5b489216db042180b14a8ef 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_groupnorm_impl.hpp" diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index 3542f73a62f05f690f12270b44c92816816743c3..ec88442fc02623767f299de36f8b6c5810569cc6 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_groupnorm_impl.hpp" diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index d627cbe7f1187e904ce8bfa452274af509708101..2222740fcceaa4b9dfe22a299dc0cd3e054e5d35 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_layernorm_impl.hpp" diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index de4133aa8369b08be2fe3c3741081d121ceb59f6..30fbe06c60d8ebfcbdd87b3e27c11c601d81635b 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_layernorm_impl.hpp" diff --git a/test/pool_fwd/CMakeLists.txt b/test/pool_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6f59b95f6fce8679a022f2bb57de80983cf261a5 --- /dev/null +++ b/test/pool_fwd/CMakeLists.txt @@ -0,0 +1,16 @@ +add_custom_target(test_pool_fwd) + +add_gtest_executable(test_avg_pool2d_fwd test_avg_pool2d_fwd.cpp) +add_gtest_executable(test_avg_pool3d_fwd test_avg_pool3d_fwd.cpp) +add_gtest_executable(test_max_pool2d_fwd test_max_pool2d_fwd.cpp) +add_gtest_executable(test_max_pool3d_fwd test_max_pool3d_fwd.cpp) + +target_link_libraries(test_avg_pool2d_fwd PRIVATE utility device_pool_fwd_instance) +target_link_libraries(test_avg_pool3d_fwd PRIVATE utility device_pool_fwd_instance) +target_link_libraries(test_max_pool2d_fwd PRIVATE utility device_pool_fwd_instance) +target_link_libraries(test_max_pool3d_fwd PRIVATE utility device_pool_fwd_instance) + +add_dependencies(test_pool_fwd test_avg_pool2d_fwd) +add_dependencies(test_pool_fwd test_avg_pool3d_fwd) +add_dependencies(test_pool_fwd test_max_pool2d_fwd) +add_dependencies(test_pool_fwd test_max_pool3d_fwd) diff --git a/test/pool_fwd/test_avg_pool2d_fwd.cpp b/test/pool_fwd/test_avg_pool2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72749fd6e1563f9b432a61871982bc7f6a2e9661 --- /dev/null +++ b/test/pool_fwd/test_avg_pool2d_fwd.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestAvgPool2dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + bool success = + ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = + ::testing::Types, std::tuple>; + +TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); +TYPED_TEST(TestAvgPool2dFwd, Test_Pool) +{ + // length, window_length, window_stride, left_pad, right_pad + this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {64, 64}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_avg_pool3d_fwd.cpp b/test/pool_fwd/test_avg_pool3d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00cc3740fa5cc775d43bf60300b693789c7e2d4d --- /dev/null +++ b/test/pool_fwd/test_avg_pool3d_fwd.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool3d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestAvgPool3dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + bool success = + ck::profiler::profile_pool3d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = + ::testing::Types, std::tuple>; + +TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes); +TYPED_TEST(TestAvgPool3dFwd, Test_Pool) +{ + // length, window_length, window_stride, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_max_pool2d_fwd.cpp b/test/pool_fwd/test_max_pool2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cf1314f4682acb528a30013011b1cdb21ea965f --- /dev/null +++ b/test/pool_fwd/test_max_pool2d_fwd.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestMaxPool2dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + // max pool + bool success = + ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + + // max pool + index + success = ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = + ::testing::Types, std::tuple>; + +TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); +TYPED_TEST(TestMaxPool2dFwd, Test_Pool) +{ + // length, window_length, window_stride, left_pad, right_pad + this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {64, 64}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_max_pool3d_fwd.cpp b/test/pool_fwd/test_max_pool3d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b0de4d90ec57aa57db1d8a173d419f879a1ea87 --- /dev/null +++ b/test/pool_fwd/test_max_pool3d_fwd.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool3d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestMaxPool3dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + // max pool + bool success = + ck::profiler::profile_pool3d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + + // max pool + index + success = ck::profiler::profile_pool3d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = + ::testing::Types, std::tuple>; + +TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes); +TYPED_TEST(TestMaxPool3dFwd, Test_Pool) +{ + // length, window_length, window_stride, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_pool_fwd_common.hpp b/test/pool_fwd/test_pool_fwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f018635170d9b3eb2235563fd04faefec38fea11 --- /dev/null +++ b/test/pool_fwd/test_pool_fwd_common.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" + +using F16 = ck::half_t; +using F32 = float; +using I32 = int32_t; +using ck::index_t; + +struct PoolingParam +{ + PoolingParam(const std::vector& length, + const std::vector& window_spatial_lengths, + const std::vector& window_strides, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + : length_(length), + window_spatial_lengths_(window_spatial_lengths), + window_strides_(window_strides), + input_left_pads_(input_left_pads), + input_right_pads_(input_right_pads) + { + } + std::vector length_; + std::vector window_spatial_lengths_; + std::vector window_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; +}; diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 3f4d0676b4da84751dd74cd4e76266059c056d67..1ab452442d0b219820e208420a4c4c04a019e9a1 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index c616a68e741e14119d70e0d1d4ba508e371dea78..0301669c5a7578b7c6a0c6ba160685b13295b492 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index 1f9ba0064cb3b22c394c75d5119f71a188d11a5a..b3328e4b365268af641a6ecb1a65ded872f94ddd 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/softmax/test_softmax_interface.cpp b/test/softmax/test_softmax_interface.cpp index 8cac0ba0f52b035dae2fb52f0b0346d3267e5b7b..25f666f0ea99eb12e0fa248fdd573aebed0de7a6 100644 --- a/test/softmax/test_softmax_interface.cpp +++ b/test/softmax/test_softmax_interface.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/softmax/test_softmax_rank3.cpp b/test/softmax/test_softmax_rank3.cpp index 5691ee3f6cd5be1bd8ff1075cb57fcfd86aa8322..43ae11bf1f243e930517e618b167af30e1599e0d 100644 --- a/test/softmax/test_softmax_rank3.cpp +++ b/test/softmax/test_softmax_rank3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -13,7 +13,6 @@ using I = ck::Number; using F16 = ck::half_t; using F32 = float; -using I8 = int8_t; template class TestSoftmax : public ck::TestSoftmax @@ -24,8 +23,7 @@ class TestSoftmax : public ck::TestSoftmax using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank std::tuple< F16, F32, F16, I<3>>, - std::tuple< F32, F32, F32, I<3>>, - std::tuple< I8, F32, I8, I<3>> + std::tuple< F32, F32, F32, I<3>> >; // clang-format on diff --git a/test/softmax/test_softmax_rank4.cpp b/test/softmax/test_softmax_rank4.cpp index f0b22df25ebdfaaecbf4548ec2b44fdcd477aae4..5cf96bbaa8553ab404693f555a776d22fa37f9e4 100644 --- a/test/softmax/test_softmax_rank4.cpp +++ b/test/softmax/test_softmax_rank4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -13,7 +13,6 @@ using I = ck::Number; using F16 = ck::half_t; using F32 = float; -using I8 = int8_t; template class TestSoftmax : public ck::TestSoftmax @@ -24,8 +23,7 @@ class TestSoftmax : public ck::TestSoftmax using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank std::tuple< F16, F32, F16, I<4>>, - std::tuple< F32, F32, F32, I<4>>, - std::tuple< I8, F32, I8, I<4>> + std::tuple< F32, F32, F32, I<4>> >; // clang-format on diff --git a/test/softmax/test_softmax_util.hpp b/test/softmax/test_softmax_util.hpp index 40b300cf9927de3556c586805923de9ab97db466..1409af8453b3a9d722fd5074d311f7066f295c65 100644 --- a/test/softmax/test_softmax_util.hpp +++ b/test/softmax/test_softmax_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test int init_method = 1; // integer value initialization bool log = false; std::vector strides; // intenionally empty, to get packed layout. - bool pass = ck::profiler::profile_softmax_impl( - verify_, init_method, log, bench_, in_length, strides, reduce_dims, alpha, beta); + bool pass = false; + + if constexpr(Rank == 3) + { + if(reduce_dims.size() == 1) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 2) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 3) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + } + else if constexpr(Rank == 4) + { + if(reduce_dims.size() == 1) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 2) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 3) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 4) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + }; + EXPECT_TRUE(pass); } diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp index c7f6759e819512f55c17463ffdb503e55bb8c59e..a192ecb28f1b6c67fa7ec27790349e041c6718fb 100644 --- a/test/space_filling_curve/space_filling_curve.cpp +++ b/test/space_filling_curve/space_filling_curve.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 761c15f1dd8a9d22de14a2998d73a3877b1f3003..47d8c7ed6f35236192cc9c4cfacde2e754560241 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index c70e6a407de6fce20df64dfa29f9aa19de1d9e30..49782bce6e21d4a250b0d07526abf4e39ac57b9d 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once